Mirror of https://github.com/microsoft/autogen.git (synced 2025-09-16 03:35:06 +00:00)

commit 56aed2d3d1 (parent e7cdae63fd)

Added support for streaming tool calls (#1184)

* added support for streaming tool calls
* bug fix: removed tmp assert

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
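
Background for the patch below: with stream=True, the chat API delivers each tool call as a sequence of ChoiceDeltaToolCall deltas, and the client must stitch the fragments (id, type, function name, JSON argument pieces) back into complete calls keyed by each delta's index. A minimal standalone sketch of that reconstruction idea, not the patch's own code (assumes openai>=1.1 for the chunk types; merge_tool_call_chunks is an illustrative helper):

# Illustrative helper, not part of the patch: merge streamed tool-call
# deltas into complete calls, keyed by each delta's index.
from typing import Any, Dict, List
from openai.types.chat.chat_completion_chunk import (
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)

def merge_tool_call_chunks(chunks: List[ChoiceDeltaToolCall]) -> List[Dict[str, Any]]:
    calls: Dict[int, Dict[str, Any]] = {}
    for chunk in chunks:
        call = calls.setdefault(chunk.index, {"id": "", "type": "", "function": {"name": "", "arguments": ""}})
        if chunk.id:
            call["id"] += chunk.id  # string fields accumulate across deltas
        if chunk.type:
            call["type"] = chunk.type
        if chunk.function:
            if chunk.function.name:
                call["function"]["name"] += chunk.function.name
            if chunk.function.arguments:
                call["function"]["arguments"] += chunk.function.arguments  # JSON arrives in fragments
    return [calls[ix] for ix in sorted(calls)]

chunks = [
    ChoiceDeltaToolCall(index=0, id="call_abc", type="function",
                        function=ChoiceDeltaToolCallFunction(name="get_current_weather", arguments="")),
    ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"location": "San Francisco"}')),
]
print(merge_tool_call_chunks(chunks))

The patch implements this inside OpenAIWrapper through three new static helpers (_update_dict_from_chunk, _update_function_call_from_chunk, _update_tool_calls_from_chunk), covering both the deprecated function_call field and the newer tool_calls field.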
autogen/oai/client.py
@@ -2,21 +2,36 @@ from __future__ import annotations
 import os
 import sys
-from typing import List, Optional, Dict, Callable, Union
+from typing import Any, List, Optional, Dict, Callable, Tuple, Union
 import logging
 import inspect
 from flaml.automl.logger import logger_formatter
-from pydantic import ValidationError
+from pydantic import BaseModel
 
+from autogen.oai import completion
+
 from autogen.oai.openai_utils import get_key, OAI_PRICE1K
 from autogen.token_count_utils import count_token
+from autogen._pydantic import model_dump
 
 TOOL_ENABLED = False
 try:
     import openai
+except ImportError:
+    ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
+    OpenAI = object
+else:
+    # raises exception if openai>=1 is installed and something is wrong with imports
     from openai import OpenAI, APIError, __version__ as OPENAIVERSION
+    from openai.resources import Completions
     from openai.types.chat import ChatCompletion
-    from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+    from openai.types.chat.chat_completion import ChatCompletionMessage, Choice  # type: ignore [attr-defined]
+    from openai.types.chat.chat_completion_chunk import (
+        ChoiceDeltaToolCall,
+        ChoiceDeltaToolCallFunction,
+        ChoiceDeltaFunctionCall,
+    )
     from openai.types.completion import Completion
     from openai.types.completion_usage import CompletionUsage
     import diskcache
@@ -24,9 +39,7 @@ try:
     if openai.__version__ >= "1.1.0":
         TOOL_ENABLED = True
     ERROR = None
-except ImportError:
-    ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
-    OpenAI = object
 
 logger = logging.getLogger(__name__)
 if not logger.handlers:
     # Add the console handler.
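
The restructure in the two hunks above replaces one broad try/except with try/except/else. A small illustrative sketch of the motivation, not the patch's code: under the old shape, any failure in the submodule imports was swallowed and misreported as a missing package; with else, only the probe import is guarded, which matches the new "# raises exception if openai>=1 is installed and something is wrong with imports" comment.

try:
    import openai  # probe only: is the package installed at all?
except ImportError:
    ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
else:
    # runs only when the probe succeeded; a failure here (e.g. a symbol
    # renamed in an incompatible openai release) now raises loudly instead
    # of being misreported as "openai not installed"
    from openai import OpenAI
    ERROR = None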
@@ -41,10 +54,10 @@ class OpenAIWrapper:
     cache_path_root: str = ".cache"
     extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
     openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
-    total_usage_summary: Dict = None
-    actual_usage_summary: Dict = None
+    total_usage_summary: Optional[Dict[str, Any]] = None
+    actual_usage_summary: Optional[Dict[str, Any]] = None
 
-    def __init__(self, *, config_list: List[Dict] = None, **base_config):
+    def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any):
         """
         Args:
             config_list: a list of config dicts to override the base_config.
@@ -81,7 +94,9 @@ class OpenAIWrapper:
             logger.warning("openai client was provided with an empty config_list, which may not be intended.")
         if config_list:
             config_list = [config.copy() for config in config_list]  # make a copy before modifying
-            self._clients = [self._client(config, openai_config) for config in config_list]  # could modify the config
+            self._clients: List[OpenAI] = [
+                self._client(config, openai_config) for config in config_list
+            ]  # could modify the config
             self._config_list = [
                 {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}}
                 for config in config_list
@@ -90,7 +105,9 @@ class OpenAIWrapper:
             self._clients = [self._client(extra_kwargs, openai_config)]
             self._config_list = [extra_kwargs]
 
-    def _process_for_azure(self, config: Dict, extra_kwargs: Dict, segment: str = "default"):
+    def _process_for_azure(
+        self, config: Dict[str, Any], extra_kwargs: Dict[str, Any], segment: str = "default"
+    ) -> None:
         # deal with api_version
         query_segment = f"{segment}_query"
         headers_segment = f"{segment}_headers"
@@ -123,20 +140,20 @@ class OpenAIWrapper:
         if not base_url.endswith(suffix):
             config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix
 
-    def _separate_openai_config(self, config):
+    def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """Separate the config into openai_config and extra_kwargs."""
         openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
         extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
         self._process_for_azure(openai_config, extra_kwargs)
         return openai_config, extra_kwargs
 
-    def _separate_create_config(self, config):
+    def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """Separate the config into create_config and extra_kwargs."""
         create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
         extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
         return create_config, extra_kwargs
 
-    def _client(self, config, openai_config):
+    def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI:
         """Create a client with the given config to override openai_config,
         after removing extra kwargs.
         """
@@ -148,21 +165,21 @@ class OpenAIWrapper:
     @classmethod
     def instantiate(
         cls,
-        template: str | Callable | None,
-        context: Optional[Dict] = None,
+        template: Optional[Union[str, Callable[[Dict[str, Any]], str]]],
+        context: Optional[Dict[str, Any]] = None,
         allow_format_str_template: Optional[bool] = False,
-    ):
+    ) -> Optional[str]:
         if not context or template is None:
-            return template
+            return template  # type: ignore [return-value]
         if isinstance(template, str):
             return template.format(**context) if allow_format_str_template else template
         return template(context)
 
-    def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict:
+    def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: Dict[str, Any]) -> Dict[str, Any]:
         """Prime the create_config with additional_kwargs."""
         # Validate the config
-        prompt = create_config.get("prompt")
-        messages = create_config.get("messages")
+        prompt: Optional[str] = create_config.get("prompt")
+        messages: Optional[List[Dict[str, Any]]] = create_config.get("messages")
         if (prompt is None) == (messages is None):
             raise ValueError("Either prompt or messages should be in create config but not both.")
         context = extra_kwargs.get("context")
@@ -185,11 +202,11 @@ class OpenAIWrapper:
                 }
                 if m.get("content")
                 else m
-                for m in messages
+                for m in messages  # type: ignore [union-attr]
             ]
         return params
 
-    def create(self, **config):
+    def create(self, **config: Any) -> ChatCompletion:
         """Make a completion for a given config using openai's clients.
         Besides the kwargs allowed in openai's client, we allow the following additional kwargs.
         The config in each client will be overridden by the config.
@@ -239,11 +256,11 @@ class OpenAIWrapper:
             with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
                 # Try to get the response from cache
                 key = get_key(params)
-                response = cache.get(key, None)
+                response: ChatCompletion = cache.get(key, None)
 
                 if response is not None:
                     try:
-                        response.cost
+                        response.cost  # type: ignore [attr-defined]
                     except AttributeError:
                         # update attribute if cost is not calculated
                         response.cost = self.cost(response)
@@ -264,7 +281,7 @@ class OpenAIWrapper:
                 if error_code == "content_filter":
                     # raise the error for content_filter
                     raise
-                logger.debug(f"config {i} failed", exc_info=1)
+                logger.debug(f"config {i} failed", exc_info=True)
                 if i == last:
                     raise
             else:
@@ -284,9 +301,129 @@ class OpenAIWrapper:
                     response.pass_filter = pass_filter
                     return response
                 continue  # filter is not passed; try the next config
+        raise RuntimeError("Should not reach here.")
 
-    def _completions_create(self, client, params):
-        completions = client.chat.completions if "messages" in params else client.completions
+    @staticmethod
+    def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int:
+        """Update the dict from the chunk.
+
+        Reads `chunk.field` and if present updates `d[field]` accordingly.
+
+        Args:
+            chunk: The chunk.
+            d: The dict to be updated in place.
+            field: The field.
+
+        Returns:
+            The updated dict.
+
+        """
+        completion_tokens = 0
+        assert isinstance(d, dict), d
+        if hasattr(chunk, field) and getattr(chunk, field) is not None:
+            new_value = getattr(chunk, field)
+            if isinstance(new_value, list) or isinstance(new_value, dict):
+                raise NotImplementedError(
+                    f"Field {field} is a list or dict, which is currently not supported. "
+                    "Only string and numbers are supported."
+                )
+            if field not in d:
+                d[field] = ""
+            if isinstance(new_value, str):
+                d[field] += getattr(chunk, field)
+            else:
+                d[field] = new_value
+            completion_tokens = 1
+
+        return completion_tokens
+
+    @staticmethod
+    def _update_function_call_from_chunk(
+        function_call_chunk: Union[ChoiceDeltaToolCallFunction, ChoiceDeltaFunctionCall],
+        full_function_call: Optional[Dict[str, Any]],
+        completion_tokens: int,
+    ) -> Tuple[Dict[str, Any], int]:
+        """Update the function call from the chunk.
+
+        Args:
+            function_call_chunk: The function call chunk.
+            full_function_call: The full function call.
+            completion_tokens: The number of completion tokens.
+
+        Returns:
+            The updated full function call and the updated number of completion tokens.
+
+        """
+        # Handle function call
+        if function_call_chunk:
+            if full_function_call is None:
+                full_function_call = {}
+            for field in ["name", "arguments"]:
+                completion_tokens += OpenAIWrapper._update_dict_from_chunk(
+                    function_call_chunk, full_function_call, field
+                )
+
+        if full_function_call:
+            return full_function_call, completion_tokens
+        else:
+            raise RuntimeError("Function call is not found, this should not happen.")
+
+    @staticmethod
+    def _update_tool_calls_from_chunk(
+        tool_calls_chunk: ChoiceDeltaToolCall,
+        full_tool_call: Optional[Dict[str, Any]],
+        completion_tokens: int,
+    ) -> Tuple[Dict[str, Any], int]:
+        """Update the tool call from the chunk.
+
+        Args:
+            tool_call_chunk: The tool call chunk.
+            full_tool_call: The full tool call.
+            completion_tokens: The number of completion tokens.
+
+        Returns:
+            The updated full tool call and the updated number of completion tokens.
+
+        """
+        # future proofing for when tool calls other than function calls are supported
+        if tool_calls_chunk.type and tool_calls_chunk.type != "function":
+            raise NotImplementedError(
+                f"Tool call type {tool_calls_chunk.type} is currently not supported. "
+                "Only function calls are supported."
+            )
+
+        # Handle tool call
+        assert full_tool_call is None or isinstance(full_tool_call, dict), full_tool_call
+        if tool_calls_chunk:
+            if full_tool_call is None:
+                full_tool_call = {}
+            for field in ["index", "id", "type"]:
+                completion_tokens += OpenAIWrapper._update_dict_from_chunk(tool_calls_chunk, full_tool_call, field)
+
+            if hasattr(tool_calls_chunk, "function") and tool_calls_chunk.function:
+                if "function" not in full_tool_call:
+                    full_tool_call["function"] = None
+
+                full_tool_call["function"], completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
+                    tool_calls_chunk.function, full_tool_call["function"], completion_tokens
+                )
+
+        if full_tool_call:
+            return full_tool_call, completion_tokens
+        else:
+            raise RuntimeError("Tool call is not found, this should not happen.")
+
+    def _completions_create(self, client: OpenAI, params: Dict[str, Any]) -> ChatCompletion:
+        """Create a completion for a given config using openai's client.
+
+        Args:
+            client: The openai client.
+            params: The params for the completion.
+
+        Returns:
+            The completion.
+        """
+        completions: Completions = client.chat.completions if "messages" in params else client.completions  # type: ignore [attr-defined]
         # If streaming is enabled and has messages, then iterate over the chunks of the response.
         if params.get("stream", False) and "messages" in params:
             response_contents = [""] * params.get("n", 1)
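
All three helpers above funnel through _update_dict_from_chunk, whose accumulation rule is: string fields are concatenated across chunks, non-string scalars are overwritten, lists and dicts raise NotImplementedError, and each field observed in a chunk adds one to the completion-token estimate. A quick illustration calling the real helper with a stand-in chunk object (assumes autogen with openai>=1.1 installed; SimpleNamespace substitutes for a pydantic chunk, which works because the helper only uses hasattr/getattr):

from types import SimpleNamespace
from autogen import OpenAIWrapper

d: dict = {}
tokens = 0
# two deltas carrying fragments of the same "arguments" string
for delta in [SimpleNamespace(arguments='{"loc'), SimpleNamespace(arguments='ation": "SF"}')]:
    tokens += OpenAIWrapper._update_dict_from_chunk(delta, d, "arguments")
print(d, tokens)  # {'arguments': '{"location": "SF"}'} 2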
@@ -297,31 +434,52 @@ class OpenAIWrapper:
             print("\033[32m", end="")
 
             # Prepare for potential function call
-            full_function_call = None
+            full_function_call: Optional[Dict[str, Any]] = None
+            full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None
+
             # Send the chat completion request to OpenAI's API and process the response in chunks
             for chunk in completions.create(**params):
                 if chunk.choices:
                     for choice in chunk.choices:
                         content = choice.delta.content
-                        function_call_chunk = choice.delta.function_call
+                        tool_calls_chunks = choice.delta.tool_calls
                         finish_reasons[choice.index] = choice.finish_reason
+
+                        # todo: remove this after function calls are removed from the API
+                        # the code should work regardless of whether function calls are removed or not, but test_chat_functions_stream should fail
+                        # begin block
+                        function_call_chunk = (
+                            choice.delta.function_call if hasattr(choice.delta, "function_call") else None
+                        )
                         # Handle function call
                         if function_call_chunk:
-                            if hasattr(function_call_chunk, "name") and function_call_chunk.name:
-                                if full_function_call is None:
-                                    full_function_call = {"name": "", "arguments": ""}
-                                full_function_call["name"] += function_call_chunk.name
-                                completion_tokens += 1
-                            if hasattr(function_call_chunk, "arguments") and function_call_chunk.arguments:
-                                full_function_call["arguments"] += function_call_chunk.arguments
-                                completion_tokens += 1
-                        if choice.finish_reason == "function_call":
-                            # Need something here? I don't think so.
-                            pass
+                            # Handle function call
+                            if function_call_chunk:
+                                full_function_call, completion_tokens = self._update_function_call_from_chunk(
+                                    function_call_chunk, full_function_call, completion_tokens
+                                )
                         if not content:
                             continue
-                        # End handle function call
+                        # end block
+
+                        # Handle tool calls
+                        if tool_calls_chunks:
+                            for tool_calls_chunk in tool_calls_chunks:
+                                # the current tool call to be reconstructed
+                                ix = tool_calls_chunk.index
+                                if full_tool_calls is None:
+                                    full_tool_calls = []
+                                if ix >= len(full_tool_calls):
+                                    # in case ix is not sequential
+                                    full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1)
+
+                                full_tool_calls[ix], completion_tokens = self._update_tool_calls_from_chunk(
+                                    tool_calls_chunk, full_tool_calls[ix], completion_tokens
+                                )
+                                if not content:
+                                    continue
+
+                        # End handle tool calls
+
                         # If content is present, print it to the terminal and update response variables
                         if content is not None:
@@ -329,7 +487,8 @@ class OpenAIWrapper:
                             response_contents[choice.index] += content
                             completion_tokens += 1
                         else:
-                            print()
+                            # print()
+                            pass
 
             # Reset the terminal text color
             print("\033[0m\n")
@@ -356,17 +515,23 @@ class OpenAIWrapper:
                         index=i,
                         finish_reason=finish_reasons[i],
                         message=ChatCompletionMessage(
-                            role="assistant", content=response_contents[i], function_call=full_function_call
+                            role="assistant",
+                            content=response_contents[i],
+                            function_call=full_function_call,
+                            tool_calls=full_tool_calls,
                         ),
                         logprobs=None,
                     )
                 else:
                     # OpenAI versions below 1.5.0
-                    choice = Choice(
+                    choice = Choice(  # type: ignore [call-arg]
                         index=i,
                         finish_reason=finish_reasons[i],
                         message=ChatCompletionMessage(
-                            role="assistant", content=response_contents[i], function_call=full_function_call
+                            role="assistant",
+                            content=response_contents[i],
+                            function_call=full_function_call,
+                            tool_calls=full_tool_calls,
                         ),
                     )
@@ -379,7 +544,7 @@ class OpenAIWrapper:
 
         return response
 
-    def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None:
+    def _update_usage_summary(self, response: Union[ChatCompletion, Completion], use_cache: bool) -> None:
         """Update the usage summary.
 
         Usage is calculated no matter filter is passed or not.
@@ -391,17 +556,17 @@ class OpenAIWrapper:
             usage.completion_tokens = 0 if usage.completion_tokens is None else usage.completion_tokens
             usage.total_tokens = 0 if usage.total_tokens is None else usage.total_tokens
         except (AttributeError, AssertionError):
-            logger.debug("Usage attribute is not found in the response.", exc_info=1)
+            logger.debug("Usage attribute is not found in the response.", exc_info=True)
             return
 
-        def update_usage(usage_summary):
+        def update_usage(usage_summary: Optional[Dict[str, Any]]) -> Dict[str, Any]:
             if usage_summary is None:
-                usage_summary = {"total_cost": response.cost}
+                usage_summary = {"total_cost": response.cost}  # type: ignore [union-attr]
             else:
-                usage_summary["total_cost"] += response.cost
+                usage_summary["total_cost"] += response.cost  # type: ignore [union-attr]
 
             usage_summary[response.model] = {
-                "cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost,
+                "cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost,  # type: ignore [union-attr]
                 "prompt_tokens": usage_summary.get(response.model, {}).get("prompt_tokens", 0) + usage.prompt_tokens,
                 "completion_tokens": usage_summary.get(response.model, {}).get("completion_tokens", 0)
                 + usage.completion_tokens,
@@ -416,7 +581,7 @@ class OpenAIWrapper:
     def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
         """Print the usage summary."""
 
-        def print_usage(usage_summary, usage_type="total"):
+        def print_usage(usage_summary: Optional[Dict[str, Any]], usage_type: str = "total") -> None:
            word_from_type = "including" if usage_type == "total" else "excluding"
            if usage_summary is None:
                print("No actual cost incurred (all completions are using cache).", flush=True)
@@ -475,20 +640,20 @@ class OpenAIWrapper:
         model = response.model
         if model not in OAI_PRICE1K:
             # TODO: add logging to warn that the model is not found
-            logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=1)
+            logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
             return 0
 
-        n_input_tokens = response.usage.prompt_tokens
-        n_output_tokens = response.usage.completion_tokens
+        n_input_tokens = response.usage.prompt_tokens  # type: ignore [union-attr]
+        n_output_tokens = response.usage.completion_tokens  # type: ignore [union-attr]
         tmp_price1K = OAI_PRICE1K[model]
         # First value is input token rate, second value is output token rate
         if isinstance(tmp_price1K, tuple):
-            return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
-        return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
+            return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000  # type: ignore [no-any-return]
+        return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000  # type: ignore [operator]
 
     @classmethod
     def extract_text_or_completion_object(
-        cls, response: ChatCompletion | Completion
+        cls, response: Union[ChatCompletion, Completion]
     ) -> Union[List[str], List[ChatCompletionMessage]]:
         """Extract the text or ChatCompletion objects from a completion or chat response.
 
@@ -500,18 +665,18 @@ class OpenAIWrapper:
         """
         choices = response.choices
         if isinstance(response, Completion):
-            return [choice.text for choice in choices]
+            return [choice.text for choice in choices]  # type: ignore [union-attr]
 
         if TOOL_ENABLED:
-            return [
-                choice.message
-                if choice.message.function_call is not None or choice.message.tool_calls is not None
-                else choice.message.content
+            return [  # type: ignore [return-value]
+                choice.message  # type: ignore [union-attr]
+                if choice.message.function_call is not None or choice.message.tool_calls is not None  # type: ignore [union-attr]
+                else choice.message.content  # type: ignore [union-attr]
                 for choice in choices
             ]
         else:
-            return [
-                choice.message if choice.message.function_call is not None else choice.message.content
+            return [  # type: ignore [return-value]
+                choice.message if choice.message.function_call is not None else choice.message.content  # type: ignore [union-attr]
                 for choice in choices
             ]
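
Net effect of the client changes: a streamed chat completion can now finish with finish_reason == "tool_calls" and carry fully reconstructed calls on the returned message. A hedged end-to-end sketch mirroring test_chat_tools_stream from the test file further below (assumes a valid OAI_CONFIG_LIST and a tools-capable model):

from autogen import OpenAIWrapper, config_list_from_json

config_list = config_list_from_json(env_or_file="OAI_CONFIG_LIST")
client = OpenAIWrapper(config_list=config_list)
response = client.create(
    messages=[{"role": "user", "content": "What's the weather like today in San Francisco and New York?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string", "description": "The city and state"}},
                "required": ["location"],
            },
        },
    }],
    stream=True,  # deltas are printed as they arrive, then merged into one response
)
# the reconstructed calls behave like a non-streamed response's tool_calls
for tool_call in response.choices[0].message.tool_calls:
    print(tool_call.function.name, tool_call.function.arguments)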
autogen/oai/openai_utils.py
@@ -3,7 +3,7 @@ import logging
 import os
 import tempfile
 from pathlib import Path
-from typing import Dict, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union
 
 from dotenv import find_dotenv, load_dotenv
 
@@ -50,7 +50,7 @@ OAI_PRICE1K = {
 }
 
 
-def get_key(config):
+def get_key(config: Dict[str, Any]) -> str:
     """Get a unique identifier of a configuration.
 
     Args:
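
For context, get_key (given type annotations above) converts a create-call config into the disk-cache key used by OpenAIWrapper.create. A rough sketch of the idea, explicitly not the repo's exact implementation: canonical JSON serialization, so key order in the dict cannot change the cache key.

import json
from typing import Any, Dict

def get_key_sketch(config: Dict[str, Any]) -> str:
    # hypothetical stand-in for autogen.oai.openai_utils.get_key:
    # canonical JSON, so logically equal configs map to the same key
    return json.dumps(config, sort_keys=True)

assert get_key_sketch({"model": "gpt-3.5-turbo", "stream": True}) == get_key_sketch(
    {"stream": True, "model": "gpt-3.5-turbo"}
)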
test/oai/test_client_stream.py
@@ -1,3 +1,6 @@
+import json
+from typing import Any, Dict, List, Literal, Optional, Union
+from unittest.mock import MagicMock
 import pytest
 from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
 import sys
@@ -13,12 +16,21 @@ except ImportError:
 else:
     skip = False or skip_openai
 
+    # raises exception if openai>=1 is installed and something is wrong with imports
+    # otherwise the test will be skipped
+    from openai.types.chat.chat_completion_chunk import (
+        ChoiceDeltaFunctionCall,
+        ChoiceDeltaToolCall,
+        ChoiceDeltaToolCallFunction,
+    )
+    from openai.types.chat.chat_completion import ChatCompletionMessage  # type: ignore [attr-defined]
+
 KEY_LOC = "notebook"
 OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
 
 
 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
-def test_aoai_chat_completion_stream():
+def test_aoai_chat_completion_stream() -> None:
     config_list = config_list_from_json(
         env_or_file=OAI_CONFIG_LIST,
         file_location=KEY_LOC,
@@ -31,7 +43,7 @@ def test_aoai_chat_completion_stream():
 
 
 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
-def test_chat_completion_stream():
+def test_chat_completion_stream() -> None:
     config_list = config_list_from_json(
         env_or_file=OAI_CONFIG_LIST,
         file_location=KEY_LOC,
@@ -43,8 +55,147 @@ def test_chat_completion_stream():
     print(client.extract_text_or_completion_object(response))
 
 
+# no need for OpenAI, works with any model
+def test__update_dict_from_chunk() -> None:
+    # dictionaries and lists are not supported
+    mock = MagicMock()
+    empty_collections: List[Union[List[Any], Dict[str, Any]]] = [{}, []]
+    for c in empty_collections:
+        mock.c = c
+        with pytest.raises(NotImplementedError):
+            OpenAIWrapper._update_dict_from_chunk(mock, {}, "c")
+
+    org_d: Dict[str, Any] = {}
+    for i, v in enumerate([0, 1, False, True, 0.0, 1.0]):
+        field = "abcedfghijklmnopqrstuvwxyz"[i]
+        setattr(mock, field, v)
+
+        d = org_d.copy()
+        OpenAIWrapper._update_dict_from_chunk(mock, d, field)
+
+        org_d[field] = v
+        assert d == org_d
+
+    mock.s = "beginning"
+    OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
+    assert d["s"] == "beginning"
+
+    mock.s = " and"
+    OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
+    assert d["s"] == "beginning and"
+
+    mock.s = " end"
+    OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
+    assert d["s"] == "beginning and end"
+
+
 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
-def test_chat_functions_stream():
+def test__update_function_call_from_chunk() -> None:
+    function_call_chunks = [
+        ChoiceDeltaFunctionCall(arguments=None, name="get_current_weather"),
+        ChoiceDeltaFunctionCall(arguments='{"', name=None),
+        ChoiceDeltaFunctionCall(arguments="location", name=None),
+        ChoiceDeltaFunctionCall(arguments='":"', name=None),
+        ChoiceDeltaFunctionCall(arguments="San", name=None),
+        ChoiceDeltaFunctionCall(arguments=" Francisco", name=None),
+        ChoiceDeltaFunctionCall(arguments='"}', name=None),
+    ]
+    expected = {"name": "get_current_weather", "arguments": '{"location":"San Francisco"}'}
+
+    full_function_call = None
+    completion_tokens = 0
+    for function_call_chunk in function_call_chunks:
+        # print(f"{function_call_chunk=}")
+        full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
+            function_call_chunk=function_call_chunk,
+            full_function_call=full_function_call,
+            completion_tokens=completion_tokens,
+        )
+
+    print(f"{full_function_call=}")
+    print(f"{completion_tokens=}")
+
+    assert full_function_call == expected
+    assert completion_tokens == len(function_call_chunks)
+
+    ChatCompletionMessage(role="assistant", function_call=full_function_call, content=None)
+
+
+@pytest.mark.skipif(skip, reason="openai>=1 not installed")
+def test__update_tool_calls_from_chunk() -> None:
+    tool_calls_chunks = [
+        ChoiceDeltaToolCall(
+            index=0,
+            id="call_D2HOWGMekmkxXu9Ix3DUqJRv",
+            function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"),
+            type="function",
+        ),
+        ChoiceDeltaToolCall(
+            index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "S', name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="an F", name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="ranci", name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="sco, C", name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='A"}', name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=1,
+            id="call_22HgJep4nwoKU3UOr96xaLmd",
+            function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"),
+            type="function",
+        ),
+        ChoiceDeltaToolCall(
+            index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "N', name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ew Y", name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ork, ", name=None), type=None
+        ),
+        ChoiceDeltaToolCall(
+            index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='NY"}', name=None), type=None
+        ),
+    ]
+
+    full_tool_calls: List[Optional[Dict[str, Any]]] = [None, None]
+    completion_tokens = 0
+    for tool_calls_chunk in tool_calls_chunks:
+        index = tool_calls_chunk.index
+        full_tool_calls[index], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk(
+            tool_calls_chunk=tool_calls_chunk,
+            full_tool_call=full_tool_calls[index],
+            completion_tokens=completion_tokens,
+        )
+
+    print(f"{full_tool_calls=}")
+    print(f"{completion_tokens=}")
+
+    ChatCompletionMessage(role="assistant", tool_calls=full_tool_calls, content=None)
+
+
+# todo: remove when OpenAI removes functions from the API
+@pytest.mark.skipif(skip, reason="openai>=1 not installed")
+def test_chat_functions_stream() -> None:
     config_list = config_list_from_json(
         env_or_file=OAI_CONFIG_LIST,
         file_location=KEY_LOC,
@@ -76,8 +227,63 @@ def test_chat_functions_stream():
     print(client.extract_text_or_completion_object(response))
 
 
+# test for tool support instead of the deprecated function calls
 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
-def test_completion_stream():
+def test_chat_tools_stream() -> None:
+    config_list = config_list_from_json(
+        env_or_file=OAI_CONFIG_LIST,
+        file_location=KEY_LOC,
+        filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
+    )
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                    },
+                    "required": ["location"],
+                },
+            },
+        },
+    ]
+    print(f"{config_list=}")
+    client = OpenAIWrapper(config_list=config_list)
+    response = client.create(
+        # the intention is to trigger two tool invocations as a response to a single message
+        messages=[{"role": "user", "content": "What's the weather like today in San Francisco and New York?"}],
+        tools=tools,
+        stream=True,
+    )
+    print(f"{response=}")
+    print(f"{type(response)=}")
+    print(f"{client.extract_text_or_completion_object(response)=}")
+    # check response
+    choices = response.choices
+    assert isinstance(choices, list)
+    assert len(choices) == 1
+    choice = choices[0]
+    assert choice.finish_reason == "tool_calls"
+    message = choice.message
+    tool_calls = message.tool_calls
+    assert isinstance(tool_calls, list)
+    assert len(tool_calls) == 2
+    arguments = [tool_call.function.arguments for tool_call in tool_calls]
+    locations = [json.loads(argument)["location"] for argument in arguments]
+    print(f"{locations=}")
+    assert any(["San Francisco" in location for location in locations])
+    assert any(["New York" in location for location in locations])
+
+
+@pytest.mark.skipif(skip, reason="openai>=1 not installed")
+def test_completion_stream() -> None:
     config_list = config_list_openai_aoai(KEY_LOC)
     client = OpenAIWrapper(config_list=config_list)
     response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)