mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-02 18:59:48 +00:00
Convert ChatCompletionMessage to Dict after completion (#791)
* update * update * update signature * update * update * fix test funccall groupchat * reverse change * update * update * update * update * update --------- Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu> Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
a31b240100
commit
9cec541630
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
||||
- name: Coverage
|
||||
if: matrix.python-version == '3.10'
|
||||
run: |
|
||||
pip install -e .[mathchat,test]
|
||||
pip install -e .[test]
|
||||
pip uninstall -y openai
|
||||
coverage run -a -m pytest test --ignore=test/agentchat/contrib
|
||||
coverage xml
|
||||
|
||||
@ -403,7 +403,7 @@ Rules:
|
||||
print(colored(f"Failed to compress the content due to {e}", "red"), flush=True)
|
||||
return False, None
|
||||
|
||||
compressed_message = self.client.extract_text_or_function_call(response)[0]
|
||||
compressed_message = self.client.extract_text_or_completion_object(response)[0]
|
||||
assert isinstance(compressed_message, str), f"compressed_message should be a string: {compressed_message}"
|
||||
if self.compress_config["verbose"]:
|
||||
print(
|
||||
|
||||
@ -631,7 +631,12 @@ class ConversableAgent(Agent):
|
||||
response = client.create(
|
||||
context=messages[-1].pop("context", None), messages=self._oai_system_message + messages
|
||||
)
|
||||
return True, client.extract_text_or_function_call(response)[0]
|
||||
|
||||
# TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
|
||||
extracted_response = client.extract_text_or_completion_object(response)[0]
|
||||
if not isinstance(extracted_response, str):
|
||||
extracted_response = extracted_response.model_dump(mode="dict")
|
||||
return True, extracted_response
|
||||
|
||||
async def a_generate_oai_reply(
|
||||
self,
|
||||
|
||||
@ -10,7 +10,9 @@ from flaml.automl.logger import logger_formatter
|
||||
from autogen.oai.openai_utils import get_key, oai_price1k
|
||||
from autogen.token_count_utils import count_token
|
||||
|
||||
TOOL_ENABLED = False
|
||||
try:
|
||||
import openai
|
||||
from openai import OpenAI, APIError
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
|
||||
@ -18,6 +20,8 @@ try:
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
import diskcache
|
||||
|
||||
if openai.__version__ >= "1.1.0":
|
||||
TOOL_ENABLED = True
|
||||
ERROR = None
|
||||
except ImportError:
|
||||
ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
|
||||
@ -205,7 +209,7 @@ class OpenAIWrapper:
|
||||
```python
|
||||
def yes_or_no_filter(context, response):
|
||||
return context.get("yes_or_no_choice", False) is False or any(
|
||||
text in ["Yes.", "No."] for text in client.extract_text_or_function_call(response)
|
||||
text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response)
|
||||
)
|
||||
```
|
||||
|
||||
@ -442,20 +446,32 @@ class OpenAIWrapper:
|
||||
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
|
||||
|
||||
@classmethod
|
||||
def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]:
|
||||
"""Extract the text or function calls from a completion or chat response.
|
||||
def extract_text_or_completion_object(
|
||||
cls, response: ChatCompletion | Completion
|
||||
) -> Union[List[str], List[ChatCompletionMessage]]:
|
||||
"""Extract the text or ChatCompletion objects from a completion or chat response.
|
||||
|
||||
Args:
|
||||
response (ChatCompletion | Completion): The response from openai.
|
||||
|
||||
Returns:
|
||||
A list of text or function calls in the responses.
|
||||
A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
|
||||
"""
|
||||
choices = response.choices
|
||||
if isinstance(response, Completion):
|
||||
return [choice.text for choice in choices]
|
||||
|
||||
if TOOL_ENABLED:
|
||||
return [
|
||||
choice.message if choice.message.function_call is not None else choice.message.content for choice in choices
|
||||
choice.message
|
||||
if choice.message.function_call is not None or choice.message.tool_calls is not None
|
||||
else choice.message.content
|
||||
for choice in choices
|
||||
]
|
||||
else:
|
||||
return [
|
||||
choice.message if choice.message.function_call is not None else choice.message.content
|
||||
for choice in choices
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ def test_eval_math_responses():
|
||||
functions=functions,
|
||||
)
|
||||
print(response)
|
||||
responses = client.extract_text_or_function_call(response)
|
||||
responses = client.extract_text_or_completion_object(response)
|
||||
print(responses[0])
|
||||
function_call = responses[0].function_call
|
||||
name, arguments = function_call.name, json.loads(function_call.arguments)
|
||||
|
||||
@ -2,12 +2,18 @@ import pytest
|
||||
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
|
||||
from test_utils import OAI_CONFIG_LIST, KEY_LOC
|
||||
|
||||
TOOL_ENABLED = False
|
||||
try:
|
||||
from openai import OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletionMessage
|
||||
except ImportError:
|
||||
skip = True
|
||||
else:
|
||||
skip = False
|
||||
import openai
|
||||
|
||||
if openai.__version__ >= "1.1.0":
|
||||
TOOL_ENABLED = True
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
@ -24,7 +30,44 @@ def test_aoai_chat_completion():
|
||||
# response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
|
||||
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
|
||||
print(response)
|
||||
print(client.extract_text_or_function_call(response))
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip and not TOOL_ENABLED, reason="openai>=1.1.0 not installed")
|
||||
def test_oai_tool_calling_extraction():
|
||||
config_list = config_list_from_json(
|
||||
env_or_file=OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
|
||||
)
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in San Francisco?",
|
||||
},
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "getCurrentWeather",
|
||||
"description": "Get the weather in location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
@ -36,7 +79,7 @@ def test_chat_completion():
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(messages=[{"role": "user", "content": "1+1="}])
|
||||
print(response)
|
||||
print(client.extract_text_or_function_call(response))
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
@ -45,7 +88,7 @@ def test_completion():
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
|
||||
print(response)
|
||||
print(client.extract_text_or_function_call(response))
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
@ -96,6 +139,7 @@ def test_usage_summary():
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_aoai_chat_completion()
|
||||
test_oai_tool_calling_extraction()
|
||||
test_chat_completion()
|
||||
test_completion()
|
||||
test_cost()
|
||||
|
||||
@ -20,7 +20,7 @@ def test_aoai_chat_completion_stream():
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(messages=[{"role": "user", "content": "2+2="}], stream=True)
|
||||
print(response)
|
||||
print(client.extract_text_or_function_call(response))
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
@ -33,7 +33,7 @@ def test_chat_completion_stream():
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True)
|
||||
print(response)
|
||||
print(client.extract_text_or_function_call(response))
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
@ -66,7 +66,7 @@ def test_chat_functions_stream():
|
||||
stream=True,
|
||||
)
|
||||
print(response)
|
||||
print(client.extract_text_or_function_call(response))
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
@ -75,7 +75,7 @@ def test_completion_stream():
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
|
||||
print(response)
|
||||
print(client.extract_text_or_function_call(response))
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -61,7 +61,7 @@ Therefore, some changes are required for users of `pyautogen<0.2`.
|
||||
from autogen import OpenAIWrapper
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(messages=[{"role": "user", "content": "2+2="}])
|
||||
print(client.extract_text_or_function_call(response))
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
```
|
||||
- Inference parameter tuning and inference logging features are currently unavailable in `OpenAIWrapper`. Logging will be added in a future release.
|
||||
Inference parameter tuning can be done via [`flaml.tune`](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function).
|
||||
|
||||
@ -119,7 +119,7 @@ client = OpenAIWrapper()
|
||||
# ChatCompletion
|
||||
response = client.create(messages=[{"role": "user", "content": "2+2="}], model="gpt-3.5-turbo")
|
||||
# extract the response text
|
||||
print(client.extract_text_or_function_call(response))
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
# get cost of this completion
|
||||
print(response.cost)
|
||||
# Azure OpenAI endpoint
|
||||
@ -127,7 +127,7 @@ client = OpenAIWrapper(api_key=..., base_url=..., api_version=..., api_type="azu
|
||||
# Completion
|
||||
response = client.create(prompt="2+2=", model="gpt-3.5-turbo-instruct")
|
||||
# extract the response text
|
||||
print(client.extract_text_or_function_call(response))
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
```
|
||||
|
||||
@ -240,7 +240,7 @@ Another type of error is that the returned response does not satisfy a requireme
|
||||
|
||||
```python
|
||||
def valid_json_filter(response, **_):
|
||||
for text in OpenAIWrapper.extract_text_or_function_call(response):
|
||||
for text in OpenAIWrapper.extract_text_or_completion_object(response):
|
||||
try:
|
||||
json.loads(text)
|
||||
return True
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user