fix: fix ChatGPT invocation layer (and add async support) (#5979)

* ChatGPT async

* release note

* fix tests
Stefano Fiorucci 2023-10-05 18:43:26 +02:00 committed by GitHub
parent 282419d82b
commit ccc9f010bb
4 changed files with 139 additions and 55 deletions


@@ -1,10 +1,17 @@
import logging
from typing import Optional, List, Dict, Union, Any
from typing import Any, Dict, List, Optional, Union
from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
from haystack.utils.openai_utils import (
_check_openai_finish_reason,
check_openai_async_policy_violation,
check_openai_policy_violation,
count_openai_tokens_messages,
openai_async_request,
openai_request,
)
logger = logging.getLogger(__name__)
@@ -43,45 +50,6 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
"""
super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)
def _execute_openai_request(
self, prompt: Union[str, List[Dict]], base_payload: Dict, kwargs_with_defaults: Dict, stream: bool
):
"""
For more details, see [OpenAI ChatGPT API reference](https://platform.openai.com/docs/api-reference/chat).
"""
if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
messages = prompt
else:
raise ValueError(
f"The prompt format is different than what the model expects. "
f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
)
extra_payload = {"messages": messages}
payload = {**base_payload, **extra_payload}
if not stream:
response = openai_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_finish_reason(result=response, payload=payload)
assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
else:
response = openai_request(
url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
)
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
# Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
# We want to exclude it to be consistent with other invocation layers
if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
stop_words = kwargs_with_defaults["stop"]
for idx, _ in enumerate(assistant_response):
for stop_word in stop_words:
assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
return assistant_response
def _extract_token(self, event_data: Dict[str, Any]):
delta = event_data["choices"][0]["delta"]
if "content" in delta:
@@ -141,3 +109,109 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
and not "gpt-3.5-turbo-instruct" in model_name_or_path
)
return valid_model and not has_azure_parameters(**kwargs)
async def ainvoke(self, *args, **kwargs):
"""
Asynchronously invokes a prompt on the model. It takes in either a prompt string or a list of ChatML messages
and returns a list of responses using a REST invocation.
:return: The list of generated responses.
Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
"""
prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
if moderation and await check_openai_async_policy_violation(input=prompt, headers=self.headers):
logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
return []
if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
messages = prompt
else:
raise ValueError(
f"The prompt format is different than what the model expects. "
f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
)
extra_payload = {"messages": messages}
payload = {**base_payload, **extra_payload}
if not stream:
response = await openai_async_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_finish_reason(result=response, payload=payload)
assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
else:
response = await openai_async_request(
url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
)
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
# Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
# We want to exclude it to be consistent with other invocation layers
if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
stop_words = kwargs_with_defaults["stop"]
for idx, _ in enumerate(assistant_response):
for stop_word in stop_words:
assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
if moderation and await check_openai_async_policy_violation(input=assistant_response, headers=self.headers):
logger.info("Response '%s' will not be returned due to potential policy violation.", assistant_response)
return []
return assistant_response
def invoke(self, *args, **kwargs):
"""
Invokes a prompt on the model. It takes in either a prompt string or a list of ChatML messages
and returns a list of responses using a REST invocation.
:return: The list of generated responses.
Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
"""
prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
return []
if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
messages = prompt
else:
raise ValueError(
f"The prompt format is different than what the model expects. "
f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
)
extra_payload = {"messages": messages}
payload = {**base_payload, **extra_payload}
if not stream:
response = openai_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_finish_reason(result=response, payload=payload)
assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
else:
response = openai_request(
url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
)
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
# Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
# We want to exclude it to be consistent with other invocation layers
if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
stop_words = kwargs_with_defaults["stop"]
for idx, _ in enumerate(assistant_response):
for stop_word in stop_words:
assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
if moderation and check_openai_policy_violation(input=assistant_response, headers=self.headers):
logger.info("Response '%s' will not be returned due to potential policy violation.", assistant_response)
return []
return assistant_response
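For illustration only (not part of this commit's diff): a minimal sketch of how the new async path could be exercised, assuming that `ainvoke` accepts the prompt as a keyword argument just like the synchronous `invoke`, and that a real OpenAI key is available in the OPENAI_API_KEY environment variable.

import asyncio
import os

from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer


async def main():
    # A plain string prompt is wrapped into a single ChatML user message internally.
    layer = ChatGPTInvocationLayer(api_key=os.environ["OPENAI_API_KEY"], model_name_or_path="gpt-3.5-turbo")
    responses = await layer.ainvoke(prompt="Say hello in one word.")
    print(responses)  # e.g. ["Hello"]


asyncio.run(main())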


@@ -0,0 +1,6 @@
---
fixes:
- |
Fixed the bug that prevented the correct usage of the ChatGPT invocation layer
in 1.21.1.
Added async support for the ChatGPT invocation layer.
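As background (an illustration, not part of the release note): both `invoke` and `ainvoke` accept either a plain string or a list of ChatML-style messages, which the layer sends under the `messages` key of the chat completions payload.

# The two prompt shapes accepted by the invocation layer, per the validation logic in the diff above.
prompt_as_string = "Summarize the plot of Hamlet in one sentence."
prompt_as_messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Summarize the plot of Hamlet in one sentence."},
]
# A string prompt is converted to [{"role": "user", "content": prompt}] before the request is sent.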


@@ -1,13 +1,13 @@
import logging
from unittest.mock import patch
import logging
import pytest
from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
def test_default_api_base(mock_request):
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key")
@@ -19,7 +19,7 @@ def test_default_api_base(mock_request):
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
def test_custom_api_base(mock_request):
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
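A hedged sketch (not part of this commit) of how the new async path could be unit-tested in the same style, patching `openai_async_request` where it is looked up in the `chatgpt` module. It assumes the pytest-asyncio plugin is installed and that `ainvoke` takes a `prompt` keyword just like `invoke`.

import pytest
from unittest.mock import AsyncMock, patch

from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer


@pytest.mark.asyncio  # assumes the pytest-asyncio plugin is available
@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_async_request", new_callable=AsyncMock)
async def test_ainvoke_returns_message_content(mock_request):
    # Chat completions responses carry the text under choices[i]["message"]["content"].
    mock_request.return_value = {
        "choices": [{"message": {"content": "hello", "role": "assistant"}, "finish_reason": "stop"}]
    }
    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
        invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key")
    assert await invocation_layer.ainvoke(prompt="dummy_prompt") == ["hello"]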


@@ -1,20 +1,20 @@
import os
import logging
from typing import Optional, Union, List, Dict, Any, Tuple
from unittest.mock import patch, Mock, MagicMock, AsyncMock
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from prompthub import Prompt
from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
from haystack import BaseComponent, Document, MultiLabel, Pipeline
from haystack.nodes.prompt import PromptModel, PromptNode, PromptTemplate
from haystack.nodes.prompt.invocation_layer import (
AzureChatGPTInvocationLayer,
AzureOpenAIInvocationLayer,
OpenAIInvocationLayer,
ChatGPTInvocationLayer,
OpenAIInvocationLayer,
)
from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
@pytest.fixture
@@ -1082,8 +1082,8 @@ def test_content_moderation_gpt_35():
ChatGPTInvocationLayer.
"""
prompt_node = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True})
with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
"haystack.nodes.prompt.invocation_layer.open_ai.openai_request"
with patch("haystack.nodes.prompt.invocation_layer.chatgpt.check_openai_policy_violation") as mock_check, patch(
"haystack.nodes.prompt.invocation_layer.chatgpt.openai_request"
) as mock_request:
VIOLENT_TEXT = "some violent text"
mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
@@ -1093,11 +1093,15 @@ def test_content_moderation_gpt_35():
assert prompt_node(VIOLENT_TEXT) == []
# case 2: prompt passes the moderation check but the generated output fails the check
# function should also return an empty list
mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]}
mock_request.return_value = {
"choices": [{"message": {"content": VIOLENT_TEXT, "role": "assistant"}, "finish_reason": ""}]
}
assert prompt_node("normal prompt") == []
# case 3: both prompt and output pass the moderation check
# function should return the output
mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]}
mock_request.return_value = {
"choices": [{"message": {"content": "normal output", "role": "assistant"}, "finish_reason": ""}]
}
assert prompt_node("normal prompt") == ["normal output"]
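Finally, a hedged usage sketch mirroring what this test asserts: enabling content moderation on a ChatGPT-backed PromptNode. The key below is a placeholder; a real run would call both the OpenAI moderation and chat completions endpoints.

from haystack.nodes.prompt import PromptNode

node = PromptNode(
    model_name_or_path="gpt-3.5-turbo",
    api_key="YOUR_OPENAI_API_KEY",  # placeholder, not a real key
    model_kwargs={"moderate_content": True},
)

# If either the prompt or the generated answer is flagged by the moderation endpoint,
# the invocation layer logs the event and returns an empty list instead of the text.
results = node("a harmless prompt")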