mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-15 03:06:06 +00:00
Added support for streaming tool calls (#1184)
* added support for streaming tool calls * bug fix: removed tmp assert --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
e7cdae63fd
commit
56aed2d3d1
@ -2,21 +2,36 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional, Dict, Callable, Union
|
||||
from typing import Any, List, Optional, Dict, Callable, Tuple, Union
|
||||
import logging
|
||||
import inspect
|
||||
from flaml.automl.logger import logger_formatter
|
||||
from pydantic import ValidationError
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogen.oai import completion
|
||||
|
||||
from autogen.oai.openai_utils import get_key, OAI_PRICE1K
|
||||
from autogen.token_count_utils import count_token
|
||||
from autogen._pydantic import model_dump
|
||||
|
||||
TOOL_ENABLED = False
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
|
||||
OpenAI = object
|
||||
else:
|
||||
# raises exception if openai>=1 is installed and something is wrong with imports
|
||||
from openai import OpenAI, APIError, __version__ as OPENAIVERSION
|
||||
from openai.resources import Completions
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
|
||||
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaToolCall,
|
||||
ChoiceDeltaToolCallFunction,
|
||||
ChoiceDeltaFunctionCall,
|
||||
)
|
||||
from openai.types.completion import Completion
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
import diskcache
|
||||
@ -24,9 +39,7 @@ try:
|
||||
if openai.__version__ >= "1.1.0":
|
||||
TOOL_ENABLED = True
|
||||
ERROR = None
|
||||
except ImportError:
|
||||
ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
|
||||
OpenAI = object
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logger.handlers:
|
||||
# Add the console handler.
|
||||
@ -41,10 +54,10 @@ class OpenAIWrapper:
|
||||
cache_path_root: str = ".cache"
|
||||
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
|
||||
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
|
||||
total_usage_summary: Dict = None
|
||||
actual_usage_summary: Dict = None
|
||||
total_usage_summary: Optional[Dict[str, Any]] = None
|
||||
actual_usage_summary: Optional[Dict[str, Any]] = None
|
||||
|
||||
def __init__(self, *, config_list: List[Dict] = None, **base_config):
|
||||
def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any):
|
||||
"""
|
||||
Args:
|
||||
config_list: a list of config dicts to override the base_config.
|
||||
@ -81,7 +94,9 @@ class OpenAIWrapper:
|
||||
logger.warning("openai client was provided with an empty config_list, which may not be intended.")
|
||||
if config_list:
|
||||
config_list = [config.copy() for config in config_list] # make a copy before modifying
|
||||
self._clients = [self._client(config, openai_config) for config in config_list] # could modify the config
|
||||
self._clients: List[OpenAI] = [
|
||||
self._client(config, openai_config) for config in config_list
|
||||
] # could modify the config
|
||||
self._config_list = [
|
||||
{**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}}
|
||||
for config in config_list
|
||||
@ -90,7 +105,9 @@ class OpenAIWrapper:
|
||||
self._clients = [self._client(extra_kwargs, openai_config)]
|
||||
self._config_list = [extra_kwargs]
|
||||
|
||||
def _process_for_azure(self, config: Dict, extra_kwargs: Dict, segment: str = "default"):
|
||||
def _process_for_azure(
|
||||
self, config: Dict[str, Any], extra_kwargs: Dict[str, Any], segment: str = "default"
|
||||
) -> None:
|
||||
# deal with api_version
|
||||
query_segment = f"{segment}_query"
|
||||
headers_segment = f"{segment}_headers"
|
||||
@ -123,20 +140,20 @@ class OpenAIWrapper:
|
||||
if not base_url.endswith(suffix):
|
||||
config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix
|
||||
|
||||
def _separate_openai_config(self, config):
|
||||
def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""Separate the config into openai_config and extra_kwargs."""
|
||||
openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
|
||||
extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
|
||||
self._process_for_azure(openai_config, extra_kwargs)
|
||||
return openai_config, extra_kwargs
|
||||
|
||||
def _separate_create_config(self, config):
|
||||
def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""Separate the config into create_config and extra_kwargs."""
|
||||
create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
|
||||
extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
|
||||
return create_config, extra_kwargs
|
||||
|
||||
def _client(self, config, openai_config):
|
||||
def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI:
|
||||
"""Create a client with the given config to override openai_config,
|
||||
after removing extra kwargs.
|
||||
"""
|
||||
@ -148,21 +165,21 @@ class OpenAIWrapper:
|
||||
@classmethod
|
||||
def instantiate(
|
||||
cls,
|
||||
template: str | Callable | None,
|
||||
context: Optional[Dict] = None,
|
||||
template: Optional[Union[str, Callable[[Dict[str, Any]], str]]],
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
allow_format_str_template: Optional[bool] = False,
|
||||
):
|
||||
) -> Optional[str]:
|
||||
if not context or template is None:
|
||||
return template
|
||||
return template # type: ignore [return-value]
|
||||
if isinstance(template, str):
|
||||
return template.format(**context) if allow_format_str_template else template
|
||||
return template(context)
|
||||
|
||||
def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict:
|
||||
def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Prime the create_config with additional_kwargs."""
|
||||
# Validate the config
|
||||
prompt = create_config.get("prompt")
|
||||
messages = create_config.get("messages")
|
||||
prompt: Optional[str] = create_config.get("prompt")
|
||||
messages: Optional[List[Dict[str, Any]]] = create_config.get("messages")
|
||||
if (prompt is None) == (messages is None):
|
||||
raise ValueError("Either prompt or messages should be in create config but not both.")
|
||||
context = extra_kwargs.get("context")
|
||||
@ -185,11 +202,11 @@ class OpenAIWrapper:
|
||||
}
|
||||
if m.get("content")
|
||||
else m
|
||||
for m in messages
|
||||
for m in messages # type: ignore [union-attr]
|
||||
]
|
||||
return params
|
||||
|
||||
def create(self, **config):
|
||||
def create(self, **config: Any) -> ChatCompletion:
|
||||
"""Make a completion for a given config using openai's clients.
|
||||
Besides the kwargs allowed in openai's client, we allow the following additional kwargs.
|
||||
The config in each client will be overridden by the config.
|
||||
@ -239,11 +256,11 @@ class OpenAIWrapper:
|
||||
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
|
||||
# Try to get the response from cache
|
||||
key = get_key(params)
|
||||
response = cache.get(key, None)
|
||||
response: ChatCompletion = cache.get(key, None)
|
||||
|
||||
if response is not None:
|
||||
try:
|
||||
response.cost
|
||||
response.cost # type: ignore [attr-defined]
|
||||
except AttributeError:
|
||||
# update attribute if cost is not calculated
|
||||
response.cost = self.cost(response)
|
||||
@ -264,7 +281,7 @@ class OpenAIWrapper:
|
||||
if error_code == "content_filter":
|
||||
# raise the error for content_filter
|
||||
raise
|
||||
logger.debug(f"config {i} failed", exc_info=1)
|
||||
logger.debug(f"config {i} failed", exc_info=True)
|
||||
if i == last:
|
||||
raise
|
||||
else:
|
||||
@ -284,9 +301,129 @@ class OpenAIWrapper:
|
||||
response.pass_filter = pass_filter
|
||||
return response
|
||||
continue # filter is not passed; try the next config
|
||||
raise RuntimeError("Should not reach here.")
|
||||
|
||||
def _completions_create(self, client, params):
|
||||
completions = client.chat.completions if "messages" in params else client.completions
|
||||
@staticmethod
|
||||
def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int:
|
||||
"""Update the dict from the chunk.
|
||||
|
||||
Reads `chunk.field` and if present updates `d[field]` accordingly.
|
||||
|
||||
Args:
|
||||
chunk: The chunk.
|
||||
d: The dict to be updated in place.
|
||||
field: The field.
|
||||
|
||||
Returns:
|
||||
The updated dict.
|
||||
|
||||
"""
|
||||
completion_tokens = 0
|
||||
assert isinstance(d, dict), d
|
||||
if hasattr(chunk, field) and getattr(chunk, field) is not None:
|
||||
new_value = getattr(chunk, field)
|
||||
if isinstance(new_value, list) or isinstance(new_value, dict):
|
||||
raise NotImplementedError(
|
||||
f"Field {field} is a list or dict, which is currently not supported. "
|
||||
"Only string and numbers are supported."
|
||||
)
|
||||
if field not in d:
|
||||
d[field] = ""
|
||||
if isinstance(new_value, str):
|
||||
d[field] += getattr(chunk, field)
|
||||
else:
|
||||
d[field] = new_value
|
||||
completion_tokens = 1
|
||||
|
||||
return completion_tokens
|
||||
|
||||
@staticmethod
|
||||
def _update_function_call_from_chunk(
|
||||
function_call_chunk: Union[ChoiceDeltaToolCallFunction, ChoiceDeltaFunctionCall],
|
||||
full_function_call: Optional[Dict[str, Any]],
|
||||
completion_tokens: int,
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
"""Update the function call from the chunk.
|
||||
|
||||
Args:
|
||||
function_call_chunk: The function call chunk.
|
||||
full_function_call: The full function call.
|
||||
completion_tokens: The number of completion tokens.
|
||||
|
||||
Returns:
|
||||
The updated full function call and the updated number of completion tokens.
|
||||
|
||||
"""
|
||||
# Handle function call
|
||||
if function_call_chunk:
|
||||
if full_function_call is None:
|
||||
full_function_call = {}
|
||||
for field in ["name", "arguments"]:
|
||||
completion_tokens += OpenAIWrapper._update_dict_from_chunk(
|
||||
function_call_chunk, full_function_call, field
|
||||
)
|
||||
|
||||
if full_function_call:
|
||||
return full_function_call, completion_tokens
|
||||
else:
|
||||
raise RuntimeError("Function call is not found, this should not happen.")
|
||||
|
||||
@staticmethod
|
||||
def _update_tool_calls_from_chunk(
|
||||
tool_calls_chunk: ChoiceDeltaToolCall,
|
||||
full_tool_call: Optional[Dict[str, Any]],
|
||||
completion_tokens: int,
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
"""Update the tool call from the chunk.
|
||||
|
||||
Args:
|
||||
tool_call_chunk: The tool call chunk.
|
||||
full_tool_call: The full tool call.
|
||||
completion_tokens: The number of completion tokens.
|
||||
|
||||
Returns:
|
||||
The updated full tool call and the updated number of completion tokens.
|
||||
|
||||
"""
|
||||
# future proofing for when tool calls other than function calls are supported
|
||||
if tool_calls_chunk.type and tool_calls_chunk.type != "function":
|
||||
raise NotImplementedError(
|
||||
f"Tool call type {tool_calls_chunk.type} is currently not supported. "
|
||||
"Only function calls are supported."
|
||||
)
|
||||
|
||||
# Handle tool call
|
||||
assert full_tool_call is None or isinstance(full_tool_call, dict), full_tool_call
|
||||
if tool_calls_chunk:
|
||||
if full_tool_call is None:
|
||||
full_tool_call = {}
|
||||
for field in ["index", "id", "type"]:
|
||||
completion_tokens += OpenAIWrapper._update_dict_from_chunk(tool_calls_chunk, full_tool_call, field)
|
||||
|
||||
if hasattr(tool_calls_chunk, "function") and tool_calls_chunk.function:
|
||||
if "function" not in full_tool_call:
|
||||
full_tool_call["function"] = None
|
||||
|
||||
full_tool_call["function"], completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
|
||||
tool_calls_chunk.function, full_tool_call["function"], completion_tokens
|
||||
)
|
||||
|
||||
if full_tool_call:
|
||||
return full_tool_call, completion_tokens
|
||||
else:
|
||||
raise RuntimeError("Tool call is not found, this should not happen.")
|
||||
|
||||
def _completions_create(self, client: OpenAI, params: Dict[str, Any]) -> ChatCompletion:
|
||||
"""Create a completion for a given config using openai's client.
|
||||
|
||||
Args:
|
||||
client: The openai client.
|
||||
params: The params for the completion.
|
||||
|
||||
Returns:
|
||||
The completion.
|
||||
"""
|
||||
completions: Completions = client.chat.completions if "messages" in params else client.completions # type: ignore [attr-defined]
|
||||
# If streaming is enabled and has messages, then iterate over the chunks of the response.
|
||||
if params.get("stream", False) and "messages" in params:
|
||||
response_contents = [""] * params.get("n", 1)
|
||||
@ -297,31 +434,52 @@ class OpenAIWrapper:
|
||||
print("\033[32m", end="")
|
||||
|
||||
# Prepare for potential function call
|
||||
full_function_call = None
|
||||
full_function_call: Optional[Dict[str, Any]] = None
|
||||
full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None
|
||||
|
||||
# Send the chat completion request to OpenAI's API and process the response in chunks
|
||||
for chunk in completions.create(**params):
|
||||
if chunk.choices:
|
||||
for choice in chunk.choices:
|
||||
content = choice.delta.content
|
||||
function_call_chunk = choice.delta.function_call
|
||||
tool_calls_chunks = choice.delta.tool_calls
|
||||
finish_reasons[choice.index] = choice.finish_reason
|
||||
|
||||
# todo: remove this after function calls are removed from the API
|
||||
# the code should work regardless of whether function calls are removed or not, but test_chat_functions_stream should fail
|
||||
# begin block
|
||||
function_call_chunk = (
|
||||
choice.delta.function_call if hasattr(choice.delta, "function_call") else None
|
||||
)
|
||||
# Handle function call
|
||||
if function_call_chunk:
|
||||
if hasattr(function_call_chunk, "name") and function_call_chunk.name:
|
||||
if full_function_call is None:
|
||||
full_function_call = {"name": "", "arguments": ""}
|
||||
full_function_call["name"] += function_call_chunk.name
|
||||
completion_tokens += 1
|
||||
if hasattr(function_call_chunk, "arguments") and function_call_chunk.arguments:
|
||||
full_function_call["arguments"] += function_call_chunk.arguments
|
||||
completion_tokens += 1
|
||||
if choice.finish_reason == "function_call":
|
||||
# Need something here? I don't think so.
|
||||
pass
|
||||
if not content:
|
||||
continue
|
||||
# End handle function call
|
||||
# Handle function call
|
||||
if function_call_chunk:
|
||||
full_function_call, completion_tokens = self._update_function_call_from_chunk(
|
||||
function_call_chunk, full_function_call, completion_tokens
|
||||
)
|
||||
if not content:
|
||||
continue
|
||||
# end block
|
||||
|
||||
# Handle tool calls
|
||||
if tool_calls_chunks:
|
||||
for tool_calls_chunk in tool_calls_chunks:
|
||||
# the current tool call to be reconstructed
|
||||
ix = tool_calls_chunk.index
|
||||
if full_tool_calls is None:
|
||||
full_tool_calls = []
|
||||
if ix >= len(full_tool_calls):
|
||||
# in case ix is not sequential
|
||||
full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1)
|
||||
|
||||
full_tool_calls[ix], completion_tokens = self._update_tool_calls_from_chunk(
|
||||
tool_calls_chunk, full_tool_calls[ix], completion_tokens
|
||||
)
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# End handle tool calls
|
||||
|
||||
# If content is present, print it to the terminal and update response variables
|
||||
if content is not None:
|
||||
@ -329,7 +487,8 @@ class OpenAIWrapper:
|
||||
response_contents[choice.index] += content
|
||||
completion_tokens += 1
|
||||
else:
|
||||
print()
|
||||
# print()
|
||||
pass
|
||||
|
||||
# Reset the terminal text color
|
||||
print("\033[0m\n")
|
||||
@ -356,17 +515,23 @@ class OpenAIWrapper:
|
||||
index=i,
|
||||
finish_reason=finish_reasons[i],
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content=response_contents[i], function_call=full_function_call
|
||||
role="assistant",
|
||||
content=response_contents[i],
|
||||
function_call=full_function_call,
|
||||
tool_calls=full_tool_calls,
|
||||
),
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
# OpenAI versions below 1.5.0
|
||||
choice = Choice(
|
||||
choice = Choice( # type: ignore [call-arg]
|
||||
index=i,
|
||||
finish_reason=finish_reasons[i],
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content=response_contents[i], function_call=full_function_call
|
||||
role="assistant",
|
||||
content=response_contents[i],
|
||||
function_call=full_function_call,
|
||||
tool_calls=full_tool_calls,
|
||||
),
|
||||
)
|
||||
|
||||
@ -379,7 +544,7 @@ class OpenAIWrapper:
|
||||
|
||||
return response
|
||||
|
||||
def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None:
|
||||
def _update_usage_summary(self, response: Union[ChatCompletion, Completion], use_cache: bool) -> None:
|
||||
"""Update the usage summary.
|
||||
|
||||
Usage is calculated no matter filter is passed or not.
|
||||
@ -391,17 +556,17 @@ class OpenAIWrapper:
|
||||
usage.completion_tokens = 0 if usage.completion_tokens is None else usage.completion_tokens
|
||||
usage.total_tokens = 0 if usage.total_tokens is None else usage.total_tokens
|
||||
except (AttributeError, AssertionError):
|
||||
logger.debug("Usage attribute is not found in the response.", exc_info=1)
|
||||
logger.debug("Usage attribute is not found in the response.", exc_info=True)
|
||||
return
|
||||
|
||||
def update_usage(usage_summary):
|
||||
def update_usage(usage_summary: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
if usage_summary is None:
|
||||
usage_summary = {"total_cost": response.cost}
|
||||
usage_summary = {"total_cost": response.cost} # type: ignore [union-attr]
|
||||
else:
|
||||
usage_summary["total_cost"] += response.cost
|
||||
usage_summary["total_cost"] += response.cost # type: ignore [union-attr]
|
||||
|
||||
usage_summary[response.model] = {
|
||||
"cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost,
|
||||
"cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost, # type: ignore [union-attr]
|
||||
"prompt_tokens": usage_summary.get(response.model, {}).get("prompt_tokens", 0) + usage.prompt_tokens,
|
||||
"completion_tokens": usage_summary.get(response.model, {}).get("completion_tokens", 0)
|
||||
+ usage.completion_tokens,
|
||||
@ -416,7 +581,7 @@ class OpenAIWrapper:
|
||||
def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
|
||||
"""Print the usage summary."""
|
||||
|
||||
def print_usage(usage_summary, usage_type="total"):
|
||||
def print_usage(usage_summary: Optional[Dict[str, Any]], usage_type: str = "total") -> None:
|
||||
word_from_type = "including" if usage_type == "total" else "excluding"
|
||||
if usage_summary is None:
|
||||
print("No actual cost incurred (all completions are using cache).", flush=True)
|
||||
@ -475,20 +640,20 @@ class OpenAIWrapper:
|
||||
model = response.model
|
||||
if model not in OAI_PRICE1K:
|
||||
# TODO: add logging to warn that the model is not found
|
||||
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=1)
|
||||
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
|
||||
return 0
|
||||
|
||||
n_input_tokens = response.usage.prompt_tokens
|
||||
n_output_tokens = response.usage.completion_tokens
|
||||
n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr]
|
||||
n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr]
|
||||
tmp_price1K = OAI_PRICE1K[model]
|
||||
# First value is input token rate, second value is output token rate
|
||||
if isinstance(tmp_price1K, tuple):
|
||||
return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
|
||||
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
|
||||
return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return]
|
||||
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator]
|
||||
|
||||
@classmethod
|
||||
def extract_text_or_completion_object(
|
||||
cls, response: ChatCompletion | Completion
|
||||
cls, response: Union[ChatCompletion, Completion]
|
||||
) -> Union[List[str], List[ChatCompletionMessage]]:
|
||||
"""Extract the text or ChatCompletion objects from a completion or chat response.
|
||||
|
||||
@ -500,18 +665,18 @@ class OpenAIWrapper:
|
||||
"""
|
||||
choices = response.choices
|
||||
if isinstance(response, Completion):
|
||||
return [choice.text for choice in choices]
|
||||
return [choice.text for choice in choices] # type: ignore [union-attr]
|
||||
|
||||
if TOOL_ENABLED:
|
||||
return [
|
||||
choice.message
|
||||
if choice.message.function_call is not None or choice.message.tool_calls is not None
|
||||
else choice.message.content
|
||||
return [ # type: ignore [return-value]
|
||||
choice.message # type: ignore [union-attr]
|
||||
if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
|
||||
else choice.message.content # type: ignore [union-attr]
|
||||
for choice in choices
|
||||
]
|
||||
else:
|
||||
return [
|
||||
choice.message if choice.message.function_call is not None else choice.message.content
|
||||
return [ # type: ignore [return-value]
|
||||
choice.message if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr]
|
||||
for choice in choices
|
||||
]
|
||||
|
||||
|
@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
|
||||
@ -50,7 +50,7 @@ OAI_PRICE1K = {
|
||||
}
|
||||
|
||||
|
||||
def get_key(config):
|
||||
def get_key(config: Dict[str, Any]) -> str:
|
||||
"""Get a unique identifier of a configuration.
|
||||
|
||||
Args:
|
||||
|
@ -1,3 +1,6 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
|
||||
import sys
|
||||
@ -13,12 +16,21 @@ except ImportError:
|
||||
else:
|
||||
skip = False or skip_openai
|
||||
|
||||
# raises exception if openai>=1 is installed and something is wrong with imports
|
||||
# otherwise the test will be skipped
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaFunctionCall,
|
||||
ChoiceDeltaToolCall,
|
||||
ChoiceDeltaToolCallFunction,
|
||||
)
|
||||
from openai.types.chat.chat_completion import ChatCompletionMessage # type: ignore [attr-defined]
|
||||
|
||||
KEY_LOC = "notebook"
|
||||
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_aoai_chat_completion_stream():
|
||||
def test_aoai_chat_completion_stream() -> None:
|
||||
config_list = config_list_from_json(
|
||||
env_or_file=OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
@ -31,7 +43,7 @@ def test_aoai_chat_completion_stream():
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_chat_completion_stream():
|
||||
def test_chat_completion_stream() -> None:
|
||||
config_list = config_list_from_json(
|
||||
env_or_file=OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
@ -43,8 +55,147 @@ def test_chat_completion_stream():
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
# no need for OpenAI, works with any model
|
||||
def test__update_dict_from_chunk() -> None:
|
||||
# dictionaries and lists are not supported
|
||||
mock = MagicMock()
|
||||
empty_collections: List[Union[List[Any], Dict[str, Any]]] = [{}, []]
|
||||
for c in empty_collections:
|
||||
mock.c = c
|
||||
with pytest.raises(NotImplementedError):
|
||||
OpenAIWrapper._update_dict_from_chunk(mock, {}, "c")
|
||||
|
||||
org_d: Dict[str, Any] = {}
|
||||
for i, v in enumerate([0, 1, False, True, 0.0, 1.0]):
|
||||
field = "abcedfghijklmnopqrstuvwxyz"[i]
|
||||
setattr(mock, field, v)
|
||||
|
||||
d = org_d.copy()
|
||||
OpenAIWrapper._update_dict_from_chunk(mock, d, field)
|
||||
|
||||
org_d[field] = v
|
||||
assert d == org_d
|
||||
|
||||
mock.s = "beginning"
|
||||
OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
|
||||
assert d["s"] == "beginning"
|
||||
|
||||
mock.s = " and"
|
||||
OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
|
||||
assert d["s"] == "beginning and"
|
||||
|
||||
mock.s = " end"
|
||||
OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
|
||||
assert d["s"] == "beginning and end"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_chat_functions_stream():
|
||||
def test__update_function_call_from_chunk() -> None:
|
||||
function_call_chunks = [
|
||||
ChoiceDeltaFunctionCall(arguments=None, name="get_current_weather"),
|
||||
ChoiceDeltaFunctionCall(arguments='{"', name=None),
|
||||
ChoiceDeltaFunctionCall(arguments="location", name=None),
|
||||
ChoiceDeltaFunctionCall(arguments='":"', name=None),
|
||||
ChoiceDeltaFunctionCall(arguments="San", name=None),
|
||||
ChoiceDeltaFunctionCall(arguments=" Francisco", name=None),
|
||||
ChoiceDeltaFunctionCall(arguments='"}', name=None),
|
||||
]
|
||||
expected = {"name": "get_current_weather", "arguments": '{"location":"San Francisco"}'}
|
||||
|
||||
full_function_call = None
|
||||
completion_tokens = 0
|
||||
for function_call_chunk in function_call_chunks:
|
||||
# print(f"{function_call_chunk=}")
|
||||
full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
|
||||
function_call_chunk=function_call_chunk,
|
||||
full_function_call=full_function_call,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
print(f"{full_function_call=}")
|
||||
print(f"{completion_tokens=}")
|
||||
|
||||
assert full_function_call == expected
|
||||
assert completion_tokens == len(function_call_chunks)
|
||||
|
||||
ChatCompletionMessage(role="assistant", function_call=full_function_call, content=None)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test__update_tool_calls_from_chunk() -> None:
|
||||
tool_calls_chunks = [
|
||||
ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id="call_D2HOWGMekmkxXu9Ix3DUqJRv",
|
||||
function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"),
|
||||
type="function",
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "S', name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="an F", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="ranci", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="sco, C", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='A"}', name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1,
|
||||
id="call_22HgJep4nwoKU3UOr96xaLmd",
|
||||
function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"),
|
||||
type="function",
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "N', name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ew Y", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ork, ", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='NY"}', name=None), type=None
|
||||
),
|
||||
]
|
||||
|
||||
full_tool_calls: List[Optional[Dict[str, Any]]] = [None, None]
|
||||
completion_tokens = 0
|
||||
for tool_calls_chunk in tool_calls_chunks:
|
||||
index = tool_calls_chunk.index
|
||||
full_tool_calls[index], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk(
|
||||
tool_calls_chunk=tool_calls_chunk,
|
||||
full_tool_call=full_tool_calls[index],
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
print(f"{full_tool_calls=}")
|
||||
print(f"{completion_tokens=}")
|
||||
|
||||
ChatCompletionMessage(role="assistant", tool_calls=full_tool_calls, content=None)
|
||||
|
||||
|
||||
# todo: remove when OpenAI removes functions from the API
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_chat_functions_stream() -> None:
|
||||
config_list = config_list_from_json(
|
||||
env_or_file=OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
@ -76,8 +227,63 @@ def test_chat_functions_stream():
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
# test for tool support instead of the deprecated function calls
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_completion_stream():
|
||||
def test_chat_tools_stream() -> None:
|
||||
config_list = config_list_from_json(
|
||||
env_or_file=OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
|
||||
)
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
print(f"{config_list=}")
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(
|
||||
# the intention is to trigger two tool invocations as a response to a single message
|
||||
messages=[{"role": "user", "content": "What's the weather like today in San Francisco and New York?"}],
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
print(f"{response=}")
|
||||
print(f"{type(response)=}")
|
||||
print(f"{client.extract_text_or_completion_object(response)=}")
|
||||
# check response
|
||||
choices = response.choices
|
||||
assert isinstance(choices, list)
|
||||
assert len(choices) == 1
|
||||
choice = choices[0]
|
||||
assert choice.finish_reason == "tool_calls"
|
||||
message = choice.message
|
||||
tool_calls = message.tool_calls
|
||||
assert isinstance(tool_calls, list)
|
||||
assert len(tool_calls) == 2
|
||||
arguments = [tool_call.function.arguments for tool_call in tool_calls]
|
||||
locations = [json.loads(argument)["location"] for argument in arguments]
|
||||
print(f"{locations=}")
|
||||
assert any(["San Francisco" in location for location in locations])
|
||||
assert any(["New York" in location for location in locations])
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_completion_stream() -> None:
|
||||
config_list = config_list_openai_aoai(KEY_LOC)
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user