Convert ChatCompletionMessage to Dict after completion (#791)

* update

* update

* update signature

* update

* update

* fix test funccall groupchat

* reverse change

* update

* update

* update

* update

* update

---------

Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Yiran Wu 2023-12-09 22:28:13 -05:00 committed by GitHub
parent a31b240100
commit 9cec541630
9 changed files with 87 additions and 22 deletions
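In short: `OpenAIWrapper.extract_text_or_function_call` is renamed to `extract_text_or_completion_object`, which may now return `ChatCompletionMessage` objects as well as strings, and the agent reply path converts any such object to a plain dict before returning it. A minimal sketch of that conversion, assuming `openai>=1.1.0` (the message class is a pydantic v2 model, so `model_dump` is available):

```python
# Sketch only, not the committed code: turning a ChatCompletionMessage into a dict.
from openai.types.chat.chat_completion import ChatCompletionMessage

message = ChatCompletionMessage(role="assistant", content="4")
reply = message  # stand-in for an extracted completion object
if not isinstance(reply, str):
    # pydantic v2 model -> plain dict; the commit uses model_dump(mode="dict")
    reply = reply.model_dump(mode="dict")
print(reply["role"], reply["content"])  # -> assistant 4
```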

@@ -49,7 +49,7 @@ jobs:
     - name: Coverage
       if: matrix.python-version == '3.10'
       run: |
-        pip install -e .[mathchat,test]
+        pip install -e .[test]
         pip uninstall -y openai
         coverage run -a -m pytest test --ignore=test/agentchat/contrib
         coverage xml

@@ -403,7 +403,7 @@ Rules:
             print(colored(f"Failed to compress the content due to {e}", "red"), flush=True)
             return False, None
-        compressed_message = self.client.extract_text_or_function_call(response)[0]
+        compressed_message = self.client.extract_text_or_completion_object(response)[0]
         assert isinstance(compressed_message, str), f"compressed_message should be a string: {compressed_message}"
         if self.compress_config["verbose"]:
             print(

@@ -631,7 +631,12 @@ class ConversableAgent(Agent):
         response = client.create(
             context=messages[-1].pop("context", None), messages=self._oai_system_message + messages
         )
-        return True, client.extract_text_or_function_call(response)[0]
+        # TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
+        extracted_response = client.extract_text_or_completion_object(response)[0]
+        if not isinstance(extracted_response, str):
+            extracted_response = extracted_response.model_dump(mode="dict")
+        return True, extracted_response

     async def a_generate_oai_reply(
         self,
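The reply extracted above can be a `ChatCompletionMessage` carrying `function_call` or `tool_calls` rather than plain text; dumping it to a dict keeps the agent's downstream message handling uniform. A hedged illustration of the non-string branch (the function name and arguments below are made up for the example):

```python
import json

from openai.types.chat.chat_completion import ChatCompletionMessage
from openai.types.chat.chat_completion_message import FunctionCall

# Illustrative assistant message that requests a function call instead of text.
msg = ChatCompletionMessage(
    role="assistant",
    content=None,
    function_call=FunctionCall(name="add", arguments='{"a": 2, "b": 2}'),
)
extracted_response = msg  # stand-in for extract_text_or_completion_object(response)[0]
if not isinstance(extracted_response, str):
    extracted_response = extracted_response.model_dump(mode="dict")
assert extracted_response["function_call"]["name"] == "add"
print(json.loads(extracted_response["function_call"]["arguments"]))  # {'a': 2, 'b': 2}
```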

@@ -10,7 +10,9 @@ from flaml.automl.logger import logger_formatter
 from autogen.oai.openai_utils import get_key, oai_price1k
 from autogen.token_count_utils import count_token

+TOOL_ENABLED = False
 try:
+    import openai
     from openai import OpenAI, APIError
     from openai.types.chat import ChatCompletion
     from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
@@ -18,6 +20,8 @@ try:
     from openai.types.completion_usage import CompletionUsage
     import diskcache

+    if openai.__version__ >= "1.1.0":
+        TOOL_ENABLED = True
     ERROR = None
 except ImportError:
     ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
@@ -205,7 +209,7 @@ class OpenAIWrapper:
         ```python
         def yes_or_no_filter(context, response):
             return context.get("yes_or_no_choice", False) is False or any(
-                text in ["Yes.", "No."] for text in client.extract_text_or_function_call(response)
+                text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response)
             )
         ```
@@ -442,20 +446,32 @@ class OpenAIWrapper:
         return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000

     @classmethod
-    def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]:
-        """Extract the text or function calls from a completion or chat response.
+    def extract_text_or_completion_object(
+        cls, response: ChatCompletion | Completion
+    ) -> Union[List[str], List[ChatCompletionMessage]]:
+        """Extract the text or ChatCompletion objects from a completion or chat response.

         Args:
             response (ChatCompletion | Completion): The response from openai.

         Returns:
-            A list of text or function calls in the responses.
+            A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
         """
         choices = response.choices
         if isinstance(response, Completion):
             return [choice.text for choice in choices]

+        if TOOL_ENABLED:
             return [
-                choice.message if choice.message.function_call is not None else choice.message.content for choice in choices
+                choice.message
+                if choice.message.function_call is not None or choice.message.tool_calls is not None
+                else choice.message.content
+                for choice in choices
             ]
+        else:
+            return [
+                choice.message if choice.message.function_call is not None else choice.message.content
+                for choice in choices
+            ]
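With the rename, callers branch on the element type: a plain `str` for ordinary text replies, or a `ChatCompletionMessage` when `function_call`/`tool_calls` are present (given `openai>=1.1.0`). A minimal consumer sketch; `client` and `response` are assumed to exist, as in the wrapper's docstring examples:

```python
# Sketch: response came from an OpenAIWrapper.create(...) call (assumed).
for item in client.extract_text_or_completion_object(response):
    if isinstance(item, str):
        print("text:", item)
    elif item.tool_calls is not None:
        for tool_call in item.tool_calls:
            print("tool call:", tool_call.function.name, tool_call.function.arguments)
    else:
        print("function call:", item.function_call.name, item.function_call.arguments)
```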

@@ -48,7 +48,7 @@ def test_eval_math_responses():
         functions=functions,
     )
     print(response)
-    responses = client.extract_text_or_function_call(response)
+    responses = client.extract_text_or_completion_object(response)
     print(responses[0])
     function_call = responses[0].function_call
     name, arguments = function_call.name, json.loads(function_call.arguments)

@@ -2,12 +2,18 @@ import pytest
 from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
 from test_utils import OAI_CONFIG_LIST, KEY_LOC

+TOOL_ENABLED = False
 try:
     from openai import OpenAI
+    from openai.types.chat.chat_completion import ChatCompletionMessage
 except ImportError:
     skip = True
 else:
     skip = False
+    import openai
+    if openai.__version__ >= "1.1.0":
+        TOOL_ENABLED = True


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -24,7 +30,44 @@ def test_aoai_chat_completion():
     # response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
     response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))
+
+
+@pytest.mark.skipif(skip and not TOOL_ENABLED, reason="openai>=1.1.0 not installed")
+def test_oai_tool_calling_extraction():
+    config_list = config_list_from_json(
+        env_or_file=OAI_CONFIG_LIST,
+        file_location=KEY_LOC,
+        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
+    )
+    client = OpenAIWrapper(config_list=config_list)
+    response = client.create(
+        messages=[
+            {
+                "role": "user",
+                "content": "What is the weather in San Francisco?",
+            },
+        ],
+        tools=[
+            {
+                "type": "function",
+                "function": {
+                    "name": "getCurrentWeather",
+                    "description": "Get the weather in location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                            "unit": {"type": "string", "enum": ["c", "f"]},
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ],
+    )
+    print(response)
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -36,7 +79,7 @@ def test_chat_completion():
     client = OpenAIWrapper(config_list=config_list)
     response = client.create(messages=[{"role": "user", "content": "1+1="}])
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -45,7 +88,7 @@ def test_completion():
     client = OpenAIWrapper(config_list=config_list)
     response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -96,6 +139,7 @@ def test_usage_summary():
 if __name__ == "__main__":
     test_aoai_chat_completion()
+    test_oai_tool_calling_extraction()
     test_chat_completion()
     test_completion()
     test_cost()
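If the model answers the weather question in the new `test_oai_tool_calling_extraction` by calling `getCurrentWeather`, the extracted object should be a message whose `tool_calls` hold the arguments as a JSON string. A hedged sketch of decoding them (this assumes the model actually chose to call the tool, which is not guaranteed):

```python
import json

extracted = client.extract_text_or_completion_object(response)[0]
if not isinstance(extracted, str) and extracted.tool_calls:
    tool_call = extracted.tool_calls[0]
    # arguments is a JSON-encoded string, e.g. '{"location": "San Francisco, CA"}'
    print(tool_call.function.name, json.loads(tool_call.function.arguments))
```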

@@ -20,7 +20,7 @@ def test_aoai_chat_completion_stream():
     client = OpenAIWrapper(config_list=config_list)
     response = client.create(messages=[{"role": "user", "content": "2+2="}], stream=True)
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -33,7 +33,7 @@ def test_chat_completion_stream():
     client = OpenAIWrapper(config_list=config_list)
     response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True)
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -66,7 +66,7 @@ def test_chat_functions_stream():
         stream=True,
     )
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -75,7 +75,7 @@ def test_completion_stream():
     client = OpenAIWrapper(config_list=config_list)
     response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 if __name__ == "__main__":

@@ -61,7 +61,7 @@ Therefore, some changes are required for users of `pyautogen<0.2`.
   from autogen import OpenAIWrapper
   client = OpenAIWrapper(config_list=config_list)
   response = client.create(messages=[{"role": "user", "content": "2+2="}])
-  print(client.extract_text_or_function_call(response))
+  print(client.extract_text_or_completion_object(response))
   ```
 - Inference parameter tuning and inference logging features are currently unavailable in `OpenAIWrapper`. Logging will be added in a future release.
   Inference parameter tuning can be done via [`flaml.tune`](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function).

@@ -119,7 +119,7 @@ client = OpenAIWrapper()
 # ChatCompletion
 response = client.create(messages=[{"role": "user", "content": "2+2="}], model="gpt-3.5-turbo")
 # extract the response text
-print(client.extract_text_or_function_call(response))
+print(client.extract_text_or_completion_object(response))
 # get cost of this completion
 print(response.cost)
 # Azure OpenAI endpoint
@@ -127,7 +127,7 @@ client = OpenAIWrapper(api_key=..., base_url=..., api_version=..., api_type="azure")
 # Completion
 response = client.create(prompt="2+2=", model="gpt-3.5-turbo-instruct")
 # extract the response text
-print(client.extract_text_or_function_call(response))
+print(client.extract_text_or_completion_object(response))
 ```
@@ -240,7 +240,7 @@ Another type of error is that the returned response does not satisfy a requirement.
 ```python
 def valid_json_filter(response, **_):
-    for text in OpenAIWrapper.extract_text_or_function_call(response):
+    for text in OpenAIWrapper.extract_text_or_completion_object(response):
         try:
             json.loads(text)
             return True
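Note that after the rename this filter can receive `ChatCompletionMessage` objects as well as strings, and `json.loads` expects a string. A hedged sketch of a guard a caller might add (not part of the commit):

```python
import json

from autogen import OpenAIWrapper

def valid_json_filter(response, **_):
    for text in OpenAIWrapper.extract_text_or_completion_object(response):
        # Skip message objects carrying function_call/tool_calls; only text can be JSON.
        if not isinstance(text, str):
            continue
        try:
            json.loads(text)
            return True
        except ValueError:
            pass
    return False
```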