Added support for streaming tool calls (#1184)

* added support for streaming tool calls

* bug fix: removed tmp assert

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Davor Runje 2024-01-11 05:34:51 +01:00 committed by GitHub
parent e7cdae63fd
commit 56aed2d3d1
3 changed files with 445 additions and 74 deletions
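
For context, a minimal consumption sketch of the new feature, condensed from the test_chat_tools_stream test added in this commit (it assumes a valid OAI_CONFIG_LIST with a gpt-3.5-turbo entry; everything else is taken from the test, not new API surface):

import json
from autogen import OpenAIWrapper, config_list_from_json

config_list = config_list_from_json(env_or_file="OAI_CONFIG_LIST")
client = OpenAIWrapper(config_list=config_list)
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                },
                "required": ["location"],
            },
        },
    },
]
# With stream=True, the wrapper now reconstructs tool calls from the streamed
# chunks and returns an ordinary ChatCompletion, so message.tool_calls can be
# read exactly as in the non-streaming case.
response = client.create(
    messages=[{"role": "user", "content": "What's the weather like today in San Francisco and New York?"}],
    tools=tools,
    stream=True,
)
for tool_call in response.choices[0].message.tool_calls:
    print(tool_call.function.name, json.loads(tool_call.function.arguments))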

View File

@@ -2,21 +2,36 @@ from __future__ import annotations
import os
import sys
from typing import List, Optional, Dict, Callable, Union
from typing import Any, List, Optional, Dict, Callable, Tuple, Union
import logging
import inspect
from flaml.automl.logger import logger_formatter
from pydantic import ValidationError
from pydantic import BaseModel
from autogen.oai import completion
from autogen.oai.openai_utils import get_key, OAI_PRICE1K
from autogen.token_count_utils import count_token
from autogen._pydantic import model_dump
TOOL_ENABLED = False
try:
import openai
except ImportError:
ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
OpenAI = object
else:
# raises exception if openai>=1 is installed and something is wrong with imports
from openai import OpenAI, APIError, __version__ as OPENAIVERSION
from openai.resources import Completions
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
ChoiceDeltaFunctionCall,
)
from openai.types.completion import Completion
from openai.types.completion_usage import CompletionUsage
import diskcache
@@ -24,9 +39,7 @@ try:
if openai.__version__ >= "1.1.0":
TOOL_ENABLED = True
ERROR = None
except ImportError:
ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
OpenAI = object
logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
@@ -41,10 +54,10 @@ class OpenAIWrapper:
cache_path_root: str = ".cache"
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
total_usage_summary: Dict = None
actual_usage_summary: Dict = None
total_usage_summary: Optional[Dict[str, Any]] = None
actual_usage_summary: Optional[Dict[str, Any]] = None
def __init__(self, *, config_list: List[Dict] = None, **base_config):
def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any):
"""
Args:
config_list: a list of config dicts to override the base_config.
@@ -81,7 +94,9 @@ class OpenAIWrapper:
logger.warning("openai client was provided with an empty config_list, which may not be intended.")
if config_list:
config_list = [config.copy() for config in config_list] # make a copy before modifying
self._clients = [self._client(config, openai_config) for config in config_list] # could modify the config
self._clients: List[OpenAI] = [
self._client(config, openai_config) for config in config_list
] # could modify the config
self._config_list = [
{**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}}
for config in config_list
@@ -90,7 +105,9 @@ class OpenAIWrapper:
self._clients = [self._client(extra_kwargs, openai_config)]
self._config_list = [extra_kwargs]
def _process_for_azure(self, config: Dict, extra_kwargs: Dict, segment: str = "default"):
def _process_for_azure(
self, config: Dict[str, Any], extra_kwargs: Dict[str, Any], segment: str = "default"
) -> None:
# deal with api_version
query_segment = f"{segment}_query"
headers_segment = f"{segment}_headers"
@@ -123,20 +140,20 @@ class OpenAIWrapper:
if not base_url.endswith(suffix):
config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix
def _separate_openai_config(self, config):
def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Separate the config into openai_config and extra_kwargs."""
openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
self._process_for_azure(openai_config, extra_kwargs)
return openai_config, extra_kwargs
def _separate_create_config(self, config):
def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Separate the config into create_config and extra_kwargs."""
create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
return create_config, extra_kwargs
def _client(self, config, openai_config):
def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI:
"""Create a client with the given config to override openai_config,
after removing extra kwargs.
"""
@@ -148,21 +165,21 @@ class OpenAIWrapper:
@classmethod
def instantiate(
cls,
template: str | Callable | None,
context: Optional[Dict] = None,
template: Optional[Union[str, Callable[[Dict[str, Any]], str]]],
context: Optional[Dict[str, Any]] = None,
allow_format_str_template: Optional[bool] = False,
):
) -> Optional[str]:
if not context or template is None:
return template
return template # type: ignore [return-value]
if isinstance(template, str):
return template.format(**context) if allow_format_str_template else template
return template(context)
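
As a quick illustration of the template semantics (the prompt template here is hypothetical, not from this diff): a str template is filled from the context when allow_format_str_template is set, and a callable template is simply invoked with the context.

from autogen import OpenAIWrapper

out = OpenAIWrapper.instantiate(
    template="Translate to {language}.",
    context={"language": "French"},
    allow_format_str_template=True,
)
assert out == "Translate to French."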
def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict:
def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Prime the create_config with additional_kwargs."""
# Validate the config
prompt = create_config.get("prompt")
messages = create_config.get("messages")
prompt: Optional[str] = create_config.get("prompt")
messages: Optional[List[Dict[str, Any]]] = create_config.get("messages")
if (prompt is None) == (messages is None):
raise ValueError("Either prompt or messages should be in create config but not both.")
context = extra_kwargs.get("context")
@@ -185,11 +202,11 @@ class OpenAIWrapper:
}
if m.get("content")
else m
for m in messages
for m in messages # type: ignore [union-attr]
]
return params
def create(self, **config):
def create(self, **config: Any) -> ChatCompletion:
"""Make a completion for a given config using openai's clients.
Besides the kwargs allowed in openai's client, we allow the following additional kwargs.
The config in each client will be overridden by the config.
@@ -239,11 +256,11 @@ class OpenAIWrapper:
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
# Try to get the response from cache
key = get_key(params)
response = cache.get(key, None)
response: ChatCompletion = cache.get(key, None)
if response is not None:
try:
response.cost
response.cost # type: ignore [attr-defined]
except AttributeError:
# update attribute if cost is not calculated
response.cost = self.cost(response)
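
In other words, repeating a request with the same cache_seed is served from the disk cache. A small sketch, reusing the client from the snippet at the top (the seed value is illustrative):

# Both calls share the cache key derived from the request params, so the
# second one is read back from .cache/41 and incurs no new API call.
r1 = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=41)
r2 = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=41)
assert r1.cost == r2.cost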
@@ -264,7 +281,7 @@ class OpenAIWrapper:
if error_code == "content_filter":
# raise the error for content_filter
raise
logger.debug(f"config {i} failed", exc_info=1)
logger.debug(f"config {i} failed", exc_info=True)
if i == last:
raise
else:
@@ -284,9 +301,129 @@ class OpenAIWrapper:
response.pass_filter = pass_filter
return response
continue # filter is not passed; try the next config
raise RuntimeError("Should not reach here.")
def _completions_create(self, client, params):
completions = client.chat.completions if "messages" in params else client.completions
@staticmethod
def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int:
"""Update the dict from the chunk.
Reads `chunk.field` and if present updates `d[field]` accordingly.
Args:
chunk: The chunk.
d: The dict to be updated in place.
field: The field.
Returns:
The updated dict.
"""
completion_tokens = 0
assert isinstance(d, dict), d
if hasattr(chunk, field) and getattr(chunk, field) is not None:
new_value = getattr(chunk, field)
if isinstance(new_value, list) or isinstance(new_value, dict):
raise NotImplementedError(
f"Field {field} is a list or dict, which is currently not supported. "
"Only string and numbers are supported."
)
if field not in d:
d[field] = ""
if isinstance(new_value, str):
d[field] += getattr(chunk, field)
else:
d[field] = new_value
completion_tokens = 1
return completion_tokens
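
A minimal sketch of the accumulation semantics; SimpleNamespace stands in for a pydantic chunk, which works because the method only uses hasattr/getattr (compare test__update_dict_from_chunk in the test file below):

from types import SimpleNamespace
from autogen import OpenAIWrapper

d: dict = {}
# String fields are concatenated across chunks...
OpenAIWrapper._update_dict_from_chunk(SimpleNamespace(name="get_current"), d, "name")
OpenAIWrapper._update_dict_from_chunk(SimpleNamespace(name="_weather"), d, "name")
# ...while other scalars (e.g. an index) are overwritten; each field found
# in a chunk counts as one completion token.
tokens = OpenAIWrapper._update_dict_from_chunk(SimpleNamespace(index=1), d, "index")
assert d == {"name": "get_current_weather", "index": 1} and tokens == 1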
@staticmethod
def _update_function_call_from_chunk(
function_call_chunk: Union[ChoiceDeltaToolCallFunction, ChoiceDeltaFunctionCall],
full_function_call: Optional[Dict[str, Any]],
completion_tokens: int,
) -> Tuple[Dict[str, Any], int]:
"""Update the function call from the chunk.
Args:
function_call_chunk: The function call chunk.
full_function_call: The full function call.
completion_tokens: The number of completion tokens.
Returns:
The updated full function call and the updated number of completion tokens.
"""
# Handle function call
if function_call_chunk:
if full_function_call is None:
full_function_call = {}
for field in ["name", "arguments"]:
completion_tokens += OpenAIWrapper._update_dict_from_chunk(
function_call_chunk, full_function_call, field
)
if full_function_call:
return full_function_call, completion_tokens
else:
raise RuntimeError("Function call is not found, this should not happen.")
@staticmethod
def _update_tool_calls_from_chunk(
tool_calls_chunk: ChoiceDeltaToolCall,
full_tool_call: Optional[Dict[str, Any]],
completion_tokens: int,
) -> Tuple[Dict[str, Any], int]:
"""Update the tool call from the chunk.
Args:
tool_calls_chunk: The tool call chunk.
full_tool_call: The full tool call.
completion_tokens: The number of completion tokens.
Returns:
The updated full tool call and the updated number of completion tokens.
"""
# future proofing for when tool calls other than function calls are supported
if tool_calls_chunk.type and tool_calls_chunk.type != "function":
raise NotImplementedError(
f"Tool call type {tool_calls_chunk.type} is currently not supported. "
"Only function calls are supported."
)
# Handle tool call
assert full_tool_call is None or isinstance(full_tool_call, dict), full_tool_call
if tool_calls_chunk:
if full_tool_call is None:
full_tool_call = {}
for field in ["index", "id", "type"]:
completion_tokens += OpenAIWrapper._update_dict_from_chunk(tool_calls_chunk, full_tool_call, field)
if hasattr(tool_calls_chunk, "function") and tool_calls_chunk.function:
if "function" not in full_tool_call:
full_tool_call["function"] = None
full_tool_call["function"], completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
tool_calls_chunk.function, full_tool_call["function"], completion_tokens
)
if full_tool_call:
return full_tool_call, completion_tokens
else:
raise RuntimeError("Tool call is not found, this should not happen.")
def _completions_create(self, client: OpenAI, params: Dict[str, Any]) -> ChatCompletion:
"""Create a completion for a given config using openai's client.
Args:
client: The openai client.
params: The params for the completion.
Returns:
The completion.
"""
completions: Completions = client.chat.completions if "messages" in params else client.completions # type: ignore [attr-defined]
# If streaming is enabled and has messages, then iterate over the chunks of the response.
if params.get("stream", False) and "messages" in params:
response_contents = [""] * params.get("n", 1)
@@ -297,31 +434,52 @@ class OpenAIWrapper:
print("\033[32m", end="")
# Prepare for potential function call
full_function_call = None
full_function_call: Optional[Dict[str, Any]] = None
full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None
# Send the chat completion request to OpenAI's API and process the response in chunks
for chunk in completions.create(**params):
if chunk.choices:
for choice in chunk.choices:
content = choice.delta.content
function_call_chunk = choice.delta.function_call
tool_calls_chunks = choice.delta.tool_calls
finish_reasons[choice.index] = choice.finish_reason
# todo: remove this after function calls are removed from the API
# the code should work whether or not function calls are removed from the API, but test_chat_functions_stream will fail once they are
# begin block
function_call_chunk = (
choice.delta.function_call if hasattr(choice.delta, "function_call") else None
)
# Handle function call
if function_call_chunk:
if hasattr(function_call_chunk, "name") and function_call_chunk.name:
if full_function_call is None:
full_function_call = {"name": "", "arguments": ""}
full_function_call["name"] += function_call_chunk.name
completion_tokens += 1
if hasattr(function_call_chunk, "arguments") and function_call_chunk.arguments:
full_function_call["arguments"] += function_call_chunk.arguments
completion_tokens += 1
if choice.finish_reason == "function_call":
# Need something here? I don't think so.
pass
# Handle function call
if function_call_chunk:
full_function_call, completion_tokens = self._update_function_call_from_chunk(
function_call_chunk, full_function_call, completion_tokens
)
if not content:
continue
# End handle function call
# end block
# Handle tool calls
if tool_calls_chunks:
for tool_calls_chunk in tool_calls_chunks:
# the current tool call to be reconstructed
ix = tool_calls_chunk.index
if full_tool_calls is None:
full_tool_calls = []
if ix >= len(full_tool_calls):
# in case ix is not sequential
full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1)
full_tool_calls[ix], completion_tokens = self._update_tool_calls_from_chunk(
tool_calls_chunk, full_tool_calls[ix], completion_tokens
)
if not content:
continue
# End handle tool calls
# If content is present, print it to the terminal and update response variables
if content is not None:
@@ -329,7 +487,8 @@ class OpenAIWrapper:
response_contents[choice.index] += content
completion_tokens += 1
else:
print()
# print()
pass
# Reset the terminal text color
print("\033[0m\n")
@@ -356,17 +515,23 @@ class OpenAIWrapper:
index=i,
finish_reason=finish_reasons[i],
message=ChatCompletionMessage(
role="assistant", content=response_contents[i], function_call=full_function_call
role="assistant",
content=response_contents[i],
function_call=full_function_call,
tool_calls=full_tool_calls,
),
logprobs=None,
)
else:
# OpenAI versions below 1.5.0
choice = Choice(
choice = Choice( # type: ignore [call-arg]
index=i,
finish_reason=finish_reasons[i],
message=ChatCompletionMessage(
role="assistant", content=response_contents[i], function_call=full_function_call
role="assistant",
content=response_contents[i],
function_call=full_function_call,
tool_calls=full_tool_calls,
),
)
@@ -379,7 +544,7 @@ class OpenAIWrapper:
return response
def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None:
def _update_usage_summary(self, response: Union[ChatCompletion, Completion], use_cache: bool) -> None:
"""Update the usage summary.
Usage is calculated regardless of whether the filter is passed.
@@ -391,17 +556,17 @@ class OpenAIWrapper:
usage.completion_tokens = 0 if usage.completion_tokens is None else usage.completion_tokens
usage.total_tokens = 0 if usage.total_tokens is None else usage.total_tokens
except (AttributeError, AssertionError):
logger.debug("Usage attribute is not found in the response.", exc_info=1)
logger.debug("Usage attribute is not found in the response.", exc_info=True)
return
def update_usage(usage_summary):
def update_usage(usage_summary: Optional[Dict[str, Any]]) -> Dict[str, Any]:
if usage_summary is None:
usage_summary = {"total_cost": response.cost}
usage_summary = {"total_cost": response.cost} # type: ignore [union-attr]
else:
usage_summary["total_cost"] += response.cost
usage_summary["total_cost"] += response.cost # type: ignore [union-attr]
usage_summary[response.model] = {
"cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost,
"cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost, # type: ignore [union-attr]
"prompt_tokens": usage_summary.get(response.model, {}).get("prompt_tokens", 0) + usage.prompt_tokens,
"completion_tokens": usage_summary.get(response.model, {}).get("completion_tokens", 0)
+ usage.completion_tokens,
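
So after a single non-cached call, a usage summary has roughly this shape (model name and numbers are illustrative):

total_usage_summary = {
    "total_cost": 0.0025,
    "gpt-3.5-turbo": {
        "cost": 0.0025,
        "prompt_tokens": 1000,
        "completion_tokens": 500,
    },
}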
@@ -416,7 +581,7 @@ class OpenAIWrapper:
def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
"""Print the usage summary."""
def print_usage(usage_summary, usage_type="total"):
def print_usage(usage_summary: Optional[Dict[str, Any]], usage_type: str = "total") -> None:
word_from_type = "including" if usage_type == "total" else "excluding"
if usage_summary is None:
print("No actual cost incurred (all completions are using cache).", flush=True)
@@ -475,20 +640,20 @@ class OpenAIWrapper:
model = response.model
if model not in OAI_PRICE1K:
# TODO: add logging to warn that the model is not found
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=1)
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
return 0
n_input_tokens = response.usage.prompt_tokens
n_output_tokens = response.usage.completion_tokens
n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr]
n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr]
tmp_price1K = OAI_PRICE1K[model]
# First value is input token rate, second value is output token rate
if isinstance(tmp_price1K, tuple):
return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return]
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator]
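
A worked instance of the pricing branch above (the rates are illustrative, not actual OAI_PRICE1K entries):

tmp_price1K = (0.0015, 0.002)  # (input rate, output rate) per 1K tokens
n_input_tokens, n_output_tokens = 1000, 500
cost = (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
assert abs(cost - 0.0025) < 1e-12  # (1.5 + 1.0) / 1000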
@classmethod
def extract_text_or_completion_object(
cls, response: ChatCompletion | Completion
cls, response: Union[ChatCompletion, Completion]
) -> Union[List[str], List[ChatCompletionMessage]]:
"""Extract the text or ChatCompletion objects from a completion or chat response.
@@ -500,18 +665,18 @@ class OpenAIWrapper:
"""
choices = response.choices
if isinstance(response, Completion):
return [choice.text for choice in choices]
return [choice.text for choice in choices] # type: ignore [union-attr]
if TOOL_ENABLED:
return [
choice.message
if choice.message.function_call is not None or choice.message.tool_calls is not None
else choice.message.content
return [ # type: ignore [return-value]
choice.message # type: ignore [union-attr]
if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
else choice.message.content # type: ignore [union-attr]
for choice in choices
]
else:
return [
choice.message if choice.message.function_call is not None else choice.message.content
return [ # type: ignore [return-value]
choice.message if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr]
for choice in choices
]
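
Callers therefore have to handle both return shapes; a tiny sketch, reusing client and response from the snippet at the top:

for item in client.extract_text_or_completion_object(response):
    if isinstance(item, str):
        print("text reply:", item)
    else:
        # a ChatCompletionMessage carrying function_call and/or tool_calls
        print("tool calls:", item.tool_calls)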

View File

@@ -3,7 +3,7 @@ import logging
import os
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union
from dotenv import find_dotenv, load_dotenv
@@ -50,7 +50,7 @@ OAI_PRICE1K = {
}
def get_key(config):
def get_key(config: Dict[str, Any]) -> str:
"""Get a unique identifier of a configuration.
Args:

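The body of get_key falls outside this hunk. As a plausible reading only (an assumption; the real implementation may normalize or drop volatile keys first), the idea is that equivalent configs must map to the same disk-cache entry:

import json
from typing import Any, Dict

def get_key_sketch(config: Dict[str, Any]) -> str:
    # Hypothetical stand-in for get_key: serialize with sorted keys so that
    # key order in the config does not change the cache identifier.
    return json.dumps(config, sort_keys=True)
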
View File

@@ -1,3 +1,6 @@
import json
from typing import Any, Dict, List, Literal, Optional, Union
from unittest.mock import MagicMock
import pytest
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
import sys
@@ -13,12 +16,21 @@ except ImportError:
else:
skip = False or skip_openai
# raises exception if openai>=1 is installed and something is wrong with imports
# otherwise the test will be skipped
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaFunctionCall,
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from openai.types.chat.chat_completion import ChatCompletionMessage # type: ignore [attr-defined]
KEY_LOC = "notebook"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_aoai_chat_completion_stream():
def test_aoai_chat_completion_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
@@ -31,7 +43,7 @@ def test_aoai_chat_completion_stream():
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_completion_stream():
def test_chat_completion_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
@@ -43,8 +55,147 @@ def test_chat_completion_stream():
print(client.extract_text_or_completion_object(response))
# no need for OpenAI, works with any model
def test__update_dict_from_chunk() -> None:
# dictionaries and lists are not supported
mock = MagicMock()
empty_collections: List[Union[List[Any], Dict[str, Any]]] = [{}, []]
for c in empty_collections:
mock.c = c
with pytest.raises(NotImplementedError):
OpenAIWrapper._update_dict_from_chunk(mock, {}, "c")
org_d: Dict[str, Any] = {}
for i, v in enumerate([0, 1, False, True, 0.0, 1.0]):
field = "abcedfghijklmnopqrstuvwxyz"[i]
setattr(mock, field, v)
d = org_d.copy()
OpenAIWrapper._update_dict_from_chunk(mock, d, field)
org_d[field] = v
assert d == org_d
mock.s = "beginning"
OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
assert d["s"] == "beginning"
mock.s = " and"
OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
assert d["s"] == "beginning and"
mock.s = " end"
OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
assert d["s"] == "beginning and end"
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_functions_stream():
def test__update_function_call_from_chunk() -> None:
function_call_chunks = [
ChoiceDeltaFunctionCall(arguments=None, name="get_current_weather"),
ChoiceDeltaFunctionCall(arguments='{"', name=None),
ChoiceDeltaFunctionCall(arguments="location", name=None),
ChoiceDeltaFunctionCall(arguments='":"', name=None),
ChoiceDeltaFunctionCall(arguments="San", name=None),
ChoiceDeltaFunctionCall(arguments=" Francisco", name=None),
ChoiceDeltaFunctionCall(arguments='"}', name=None),
]
expected = {"name": "get_current_weather", "arguments": '{"location":"San Francisco"}'}
full_function_call = None
completion_tokens = 0
for function_call_chunk in function_call_chunks:
# print(f"{function_call_chunk=}")
full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
function_call_chunk=function_call_chunk,
full_function_call=full_function_call,
completion_tokens=completion_tokens,
)
print(f"{full_function_call=}")
print(f"{completion_tokens=}")
assert full_function_call == expected
assert completion_tokens == len(function_call_chunks)
ChatCompletionMessage(role="assistant", function_call=full_function_call, content=None)
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test__update_tool_calls_from_chunk() -> None:
tool_calls_chunks = [
ChoiceDeltaToolCall(
index=0,
id="call_D2HOWGMekmkxXu9Ix3DUqJRv",
function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"),
type="function",
),
ChoiceDeltaToolCall(
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None
),
ChoiceDeltaToolCall(
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None
),
ChoiceDeltaToolCall(
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "S', name=None), type=None
),
ChoiceDeltaToolCall(
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="an F", name=None), type=None
),
ChoiceDeltaToolCall(
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="ranci", name=None), type=None
),
ChoiceDeltaToolCall(
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="sco, C", name=None), type=None
),
ChoiceDeltaToolCall(
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='A"}', name=None), type=None
),
ChoiceDeltaToolCall(
index=1,
id="call_22HgJep4nwoKU3UOr96xaLmd",
function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"),
type="function",
),
ChoiceDeltaToolCall(
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None
),
ChoiceDeltaToolCall(
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None
),
ChoiceDeltaToolCall(
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "N', name=None), type=None
),
ChoiceDeltaToolCall(
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ew Y", name=None), type=None
),
ChoiceDeltaToolCall(
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ork, ", name=None), type=None
),
ChoiceDeltaToolCall(
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='NY"}', name=None), type=None
),
]
full_tool_calls: List[Optional[Dict[str, Any]]] = [None, None]
completion_tokens = 0
for tool_calls_chunk in tool_calls_chunks:
index = tool_calls_chunk.index
full_tool_calls[index], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk(
tool_calls_chunk=tool_calls_chunk,
full_tool_call=full_tool_calls[index],
completion_tokens=completion_tokens,
)
print(f"{full_tool_calls=}")
print(f"{completion_tokens=}")
ChatCompletionMessage(role="assistant", tool_calls=full_tool_calls, content=None)
# todo: remove when OpenAI removes functions from the API
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_functions_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
@@ -76,8 +227,63 @@ def test_chat_functions_stream():
print(client.extract_text_or_completion_object(response))
# test for tool support instead of the deprecated function calls
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_completion_stream():
def test_chat_tools_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
)
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
"required": ["location"],
},
},
},
]
print(f"{config_list=}")
client = OpenAIWrapper(config_list=config_list)
response = client.create(
# the intention is to trigger two tool invocations as a response to a single message
messages=[{"role": "user", "content": "What's the weather like today in San Francisco and New York?"}],
tools=tools,
stream=True,
)
print(f"{response=}")
print(f"{type(response)=}")
print(f"{client.extract_text_or_completion_object(response)=}")
# check response
choices = response.choices
assert isinstance(choices, list)
assert len(choices) == 1
choice = choices[0]
assert choice.finish_reason == "tool_calls"
message = choice.message
tool_calls = message.tool_calls
assert isinstance(tool_calls, list)
assert len(tool_calls) == 2
arguments = [tool_call.function.arguments for tool_call in tool_calls]
locations = [json.loads(argument)["location"] for argument in arguments]
print(f"{locations=}")
assert any(["San Francisco" in location for location in locations])
assert any(["New York" in location for location in locations])
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_completion_stream() -> None:
config_list = config_list_openai_aoai(KEY_LOC)
client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)