mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-17 02:25:44 +00:00
Convert ChatCompletionMessage to Dict after completion (#791)
* update * update * update signature * update * update * fix test funccall groupchat * reverse change * update * update * update * update * update --------- Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu> Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
a31b240100
commit
9cec541630
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
|||||||
- name: Coverage
|
- name: Coverage
|
||||||
if: matrix.python-version == '3.10'
|
if: matrix.python-version == '3.10'
|
||||||
run: |
|
run: |
|
||||||
pip install -e .[mathchat,test]
|
pip install -e .[test]
|
||||||
pip uninstall -y openai
|
pip uninstall -y openai
|
||||||
coverage run -a -m pytest test --ignore=test/agentchat/contrib
|
coverage run -a -m pytest test --ignore=test/agentchat/contrib
|
||||||
coverage xml
|
coverage xml
|
||||||
|
|||||||
@ -403,7 +403,7 @@ Rules:
|
|||||||
print(colored(f"Failed to compress the content due to {e}", "red"), flush=True)
|
print(colored(f"Failed to compress the content due to {e}", "red"), flush=True)
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
compressed_message = self.client.extract_text_or_function_call(response)[0]
|
compressed_message = self.client.extract_text_or_completion_object(response)[0]
|
||||||
assert isinstance(compressed_message, str), f"compressed_message should be a string: {compressed_message}"
|
assert isinstance(compressed_message, str), f"compressed_message should be a string: {compressed_message}"
|
||||||
if self.compress_config["verbose"]:
|
if self.compress_config["verbose"]:
|
||||||
print(
|
print(
|
||||||
|
|||||||
@ -631,7 +631,12 @@ class ConversableAgent(Agent):
|
|||||||
response = client.create(
|
response = client.create(
|
||||||
context=messages[-1].pop("context", None), messages=self._oai_system_message + messages
|
context=messages[-1].pop("context", None), messages=self._oai_system_message + messages
|
||||||
)
|
)
|
||||||
return True, client.extract_text_or_function_call(response)[0]
|
|
||||||
|
# TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
|
||||||
|
extracted_response = client.extract_text_or_completion_object(response)[0]
|
||||||
|
if not isinstance(extracted_response, str):
|
||||||
|
extracted_response = extracted_response.model_dump(mode="dict")
|
||||||
|
return True, extracted_response
|
||||||
|
|
||||||
async def a_generate_oai_reply(
|
async def a_generate_oai_reply(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -10,7 +10,9 @@ from flaml.automl.logger import logger_formatter
|
|||||||
from autogen.oai.openai_utils import get_key, oai_price1k
|
from autogen.oai.openai_utils import get_key, oai_price1k
|
||||||
from autogen.token_count_utils import count_token
|
from autogen.token_count_utils import count_token
|
||||||
|
|
||||||
|
TOOL_ENABLED = False
|
||||||
try:
|
try:
|
||||||
|
import openai
|
||||||
from openai import OpenAI, APIError
|
from openai import OpenAI, APIError
|
||||||
from openai.types.chat import ChatCompletion
|
from openai.types.chat import ChatCompletion
|
||||||
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
|
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
|
||||||
@ -18,6 +20,8 @@ try:
|
|||||||
from openai.types.completion_usage import CompletionUsage
|
from openai.types.completion_usage import CompletionUsage
|
||||||
import diskcache
|
import diskcache
|
||||||
|
|
||||||
|
if openai.__version__ >= "1.1.0":
|
||||||
|
TOOL_ENABLED = True
|
||||||
ERROR = None
|
ERROR = None
|
||||||
except ImportError:
|
except ImportError:
|
||||||
ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
|
ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
|
||||||
@ -205,7 +209,7 @@ class OpenAIWrapper:
|
|||||||
```python
|
```python
|
||||||
def yes_or_no_filter(context, response):
|
def yes_or_no_filter(context, response):
|
||||||
return context.get("yes_or_no_choice", False) is False or any(
|
return context.get("yes_or_no_choice", False) is False or any(
|
||||||
text in ["Yes.", "No."] for text in client.extract_text_or_function_call(response)
|
text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response)
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -442,21 +446,33 @@ class OpenAIWrapper:
|
|||||||
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
|
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]:
|
def extract_text_or_completion_object(
|
||||||
"""Extract the text or function calls from a completion or chat response.
|
cls, response: ChatCompletion | Completion
|
||||||
|
) -> Union[List[str], List[ChatCompletionMessage]]:
|
||||||
|
"""Extract the text or ChatCompletion objects from a completion or chat response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response (ChatCompletion | Completion): The response from openai.
|
response (ChatCompletion | Completion): The response from openai.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of text or function calls in the responses.
|
A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
|
||||||
"""
|
"""
|
||||||
choices = response.choices
|
choices = response.choices
|
||||||
if isinstance(response, Completion):
|
if isinstance(response, Completion):
|
||||||
return [choice.text for choice in choices]
|
return [choice.text for choice in choices]
|
||||||
return [
|
|
||||||
choice.message if choice.message.function_call is not None else choice.message.content for choice in choices
|
if TOOL_ENABLED:
|
||||||
]
|
return [
|
||||||
|
choice.message
|
||||||
|
if choice.message.function_call is not None or choice.message.tool_calls is not None
|
||||||
|
else choice.message.content
|
||||||
|
for choice in choices
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
choice.message if choice.message.function_call is not None else choice.message.content
|
||||||
|
for choice in choices
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# TODO: logging
|
# TODO: logging
|
||||||
|
|||||||
@ -48,7 +48,7 @@ def test_eval_math_responses():
|
|||||||
functions=functions,
|
functions=functions,
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
responses = client.extract_text_or_function_call(response)
|
responses = client.extract_text_or_completion_object(response)
|
||||||
print(responses[0])
|
print(responses[0])
|
||||||
function_call = responses[0].function_call
|
function_call = responses[0].function_call
|
||||||
name, arguments = function_call.name, json.loads(function_call.arguments)
|
name, arguments = function_call.name, json.loads(function_call.arguments)
|
||||||
|
|||||||
@ -2,12 +2,18 @@ import pytest
|
|||||||
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
|
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
|
||||||
from test_utils import OAI_CONFIG_LIST, KEY_LOC
|
from test_utils import OAI_CONFIG_LIST, KEY_LOC
|
||||||
|
|
||||||
|
TOOL_ENABLED = False
|
||||||
try:
|
try:
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
from openai.types.chat.chat_completion import ChatCompletionMessage
|
||||||
except ImportError:
|
except ImportError:
|
||||||
skip = True
|
skip = True
|
||||||
else:
|
else:
|
||||||
skip = False
|
skip = False
|
||||||
|
import openai
|
||||||
|
|
||||||
|
if openai.__version__ >= "1.1.0":
|
||||||
|
TOOL_ENABLED = True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||||
@ -24,7 +30,44 @@ def test_aoai_chat_completion():
|
|||||||
# response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
|
# response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
|
||||||
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
|
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
|
||||||
print(response)
|
print(response)
|
||||||
print(client.extract_text_or_function_call(response))
|
print(client.extract_text_or_completion_object(response))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(skip and not TOOL_ENABLED, reason="openai>=1.1.0 not installed")
|
||||||
|
def test_oai_tool_calling_extraction():
|
||||||
|
config_list = config_list_from_json(
|
||||||
|
env_or_file=OAI_CONFIG_LIST,
|
||||||
|
file_location=KEY_LOC,
|
||||||
|
filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
|
||||||
|
)
|
||||||
|
client = OpenAIWrapper(config_list=config_list)
|
||||||
|
response = client.create(
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is the weather in San Francisco?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tools=[
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "getCurrentWeather",
|
||||||
|
"description": "Get the weather in location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||||
|
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
print(client.extract_text_or_completion_object(response))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||||
@ -36,7 +79,7 @@ def test_chat_completion():
|
|||||||
client = OpenAIWrapper(config_list=config_list)
|
client = OpenAIWrapper(config_list=config_list)
|
||||||
response = client.create(messages=[{"role": "user", "content": "1+1="}])
|
response = client.create(messages=[{"role": "user", "content": "1+1="}])
|
||||||
print(response)
|
print(response)
|
||||||
print(client.extract_text_or_function_call(response))
|
print(client.extract_text_or_completion_object(response))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||||
@ -45,7 +88,7 @@ def test_completion():
|
|||||||
client = OpenAIWrapper(config_list=config_list)
|
client = OpenAIWrapper(config_list=config_list)
|
||||||
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
|
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
|
||||||
print(response)
|
print(response)
|
||||||
print(client.extract_text_or_function_call(response))
|
print(client.extract_text_or_completion_object(response))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||||
@ -96,6 +139,7 @@ def test_usage_summary():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_aoai_chat_completion()
|
test_aoai_chat_completion()
|
||||||
|
test_oai_tool_calling_extraction()
|
||||||
test_chat_completion()
|
test_chat_completion()
|
||||||
test_completion()
|
test_completion()
|
||||||
test_cost()
|
test_cost()
|
||||||
|
|||||||
@ -20,7 +20,7 @@ def test_aoai_chat_completion_stream():
|
|||||||
client = OpenAIWrapper(config_list=config_list)
|
client = OpenAIWrapper(config_list=config_list)
|
||||||
response = client.create(messages=[{"role": "user", "content": "2+2="}], stream=True)
|
response = client.create(messages=[{"role": "user", "content": "2+2="}], stream=True)
|
||||||
print(response)
|
print(response)
|
||||||
print(client.extract_text_or_function_call(response))
|
print(client.extract_text_or_completion_object(response))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||||
@ -33,7 +33,7 @@ def test_chat_completion_stream():
|
|||||||
client = OpenAIWrapper(config_list=config_list)
|
client = OpenAIWrapper(config_list=config_list)
|
||||||
response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True)
|
response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True)
|
||||||
print(response)
|
print(response)
|
||||||
print(client.extract_text_or_function_call(response))
|
print(client.extract_text_or_completion_object(response))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||||
@ -66,7 +66,7 @@ def test_chat_functions_stream():
|
|||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
print(client.extract_text_or_function_call(response))
|
print(client.extract_text_or_completion_object(response))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||||
@ -75,7 +75,7 @@ def test_completion_stream():
|
|||||||
client = OpenAIWrapper(config_list=config_list)
|
client = OpenAIWrapper(config_list=config_list)
|
||||||
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
|
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
|
||||||
print(response)
|
print(response)
|
||||||
print(client.extract_text_or_function_call(response))
|
print(client.extract_text_or_completion_object(response))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -61,7 +61,7 @@ Therefore, some changes are required for users of `pyautogen<0.2`.
|
|||||||
from autogen import OpenAIWrapper
|
from autogen import OpenAIWrapper
|
||||||
client = OpenAIWrapper(config_list=config_list)
|
client = OpenAIWrapper(config_list=config_list)
|
||||||
response = client.create(messages=[{"role": "user", "content": "2+2="}])
|
response = client.create(messages=[{"role": "user", "content": "2+2="}])
|
||||||
print(client.extract_text_or_function_call(response))
|
print(client.extract_text_or_completion_object(response))
|
||||||
```
|
```
|
||||||
- Inference parameter tuning and inference logging features are currently unavailable in `OpenAIWrapper`. Logging will be added in a future release.
|
- Inference parameter tuning and inference logging features are currently unavailable in `OpenAIWrapper`. Logging will be added in a future release.
|
||||||
Inference parameter tuning can be done via [`flaml.tune`](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function).
|
Inference parameter tuning can be done via [`flaml.tune`](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function).
|
||||||
|
|||||||
@ -119,7 +119,7 @@ client = OpenAIWrapper()
|
|||||||
# ChatCompletion
|
# ChatCompletion
|
||||||
response = client.create(messages=[{"role": "user", "content": "2+2="}], model="gpt-3.5-turbo")
|
response = client.create(messages=[{"role": "user", "content": "2+2="}], model="gpt-3.5-turbo")
|
||||||
# extract the response text
|
# extract the response text
|
||||||
print(client.extract_text_or_function_call(response))
|
print(client.extract_text_or_completion_object(response))
|
||||||
# get cost of this completion
|
# get cost of this completion
|
||||||
print(response.cost)
|
print(response.cost)
|
||||||
# Azure OpenAI endpoint
|
# Azure OpenAI endpoint
|
||||||
@ -127,7 +127,7 @@ client = OpenAIWrapper(api_key=..., base_url=..., api_version=..., api_type="azu
|
|||||||
# Completion
|
# Completion
|
||||||
response = client.create(prompt="2+2=", model="gpt-3.5-turbo-instruct")
|
response = client.create(prompt="2+2=", model="gpt-3.5-turbo-instruct")
|
||||||
# extract the response text
|
# extract the response text
|
||||||
print(client.extract_text_or_function_call(response))
|
print(client.extract_text_or_completion_object(response))
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -240,7 +240,7 @@ Another type of error is that the returned response does not satisfy a requireme
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
def valid_json_filter(response, **_):
|
def valid_json_filter(response, **_):
|
||||||
for text in OpenAIWrapper.extract_text_or_function_call(response):
|
for text in OpenAIWrapper.extract_text_or_completion_object(response):
|
||||||
try:
|
try:
|
||||||
json.loads(text)
|
json.loads(text)
|
||||||
return True
|
return True
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user