mirror of https://github.com/deepset-ai/haystack.git
synced 2026-01-07 20:46:31 +00:00
fix: fix ChatGPT invocation layer (and add async support) (#5979)
* ChatGPT async
* release note
* fix tests
This commit is contained in:
parent 282419d82b
commit ccc9f010bb
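For orientation before the diff: a minimal sketch of how the new async path could be exercised once this fix lands. This is hypothetical usage, not part of the commit; the `prompt` keyword argument and the placeholder API key are assumptions based on the code shown below.

import asyncio

from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer

async def main():
    # Sketch only: ainvoke() is the async counterpart of invoke() added by this PR.
    layer = ChatGPTInvocationLayer(api_key="OPENAI_API_KEY", model_name_or_path="gpt-3.5-turbo")  # placeholder key
    # The layer accepts either a plain string or a list of ChatML-style message dicts (assumed passed as `prompt=`).
    responses = await layer.ainvoke(prompt=[{"role": "user", "content": "Say hello in one word."}])
    print(responses)  # a list of assistant replies, e.g. ["Hello."]

asyncio.run(main())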
@@ -1,10 +1,17 @@
 import logging
-from typing import Optional, List, Dict, Union, Any
+from typing import Any, Dict, List, Optional, Union

 from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
 from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
-from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
+from haystack.utils.openai_utils import (
+    _check_openai_finish_reason,
+    check_openai_async_policy_violation,
+    check_openai_policy_violation,
+    count_openai_tokens_messages,
+    openai_async_request,
+    openai_request,
+)

 logger = logging.getLogger(__name__)
@@ -43,45 +50,6 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
         """
         super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)

-    def _execute_openai_request(
-        self, prompt: Union[str, List[Dict]], base_payload: Dict, kwargs_with_defaults: Dict, stream: bool
-    ):
-        """
-        For more details, see [OpenAI ChatGPT API reference](https://platform.openai.com/docs/api-reference/chat).
-        """
-        if isinstance(prompt, str):
-            messages = [{"role": "user", "content": prompt}]
-        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
-            messages = prompt
-        else:
-            raise ValueError(
-                f"The prompt format is different than what the model expects. "
-                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
-                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
-            )
-        extra_payload = {"messages": messages}
-        payload = {**base_payload, **extra_payload}
-        if not stream:
-            response = openai_request(url=self.url, headers=self.headers, payload=payload)
-            _check_openai_finish_reason(result=response, payload=payload)
-            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
-        else:
-            response = openai_request(
-                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
-            )
-            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
-            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
-
-        # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
-        # We want to exclude it to be consistent with other invocation layers
-        if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
-            stop_words = kwargs_with_defaults["stop"]
-            for idx, _ in enumerate(assistant_response):
-                for stop_word in stop_words:
-                    assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
-
-        return assistant_response
-
     def _extract_token(self, event_data: Dict[str, Any]):
         delta = event_data["choices"][0]["delta"]
         if "content" in delta:
@@ -141,3 +109,109 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
             and not "gpt-3.5-turbo-instruct" in model_name_or_path
         )
         return valid_model and not has_azure_parameters(**kwargs)
+
+    async def ainvoke(self, *args, **kwargs):
+        """
+        Invokes a prompt on the model. Based on the model, it takes in a prompt (or either a prompt or a list of messages)
+        and returns a list of responses using a REST invocation.
+
+        :return: The responses are being returned.
+
+        Note: Only kwargs relevant to OpenAI are passed to OpenAI rest API. Others kwargs are ignored.
+        For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+
+        if moderation and await check_openai_async_policy_violation(input=prompt, headers=self.headers):
+            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+            return []
+
+        if isinstance(prompt, str):
+            messages = [{"role": "user", "content": prompt}]
+        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
+            messages = prompt
+        else:
+            raise ValueError(
+                f"The prompt format is different than what the model expects. "
+                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
+                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
+            )
+        extra_payload = {"messages": messages}
+        payload = {**base_payload, **extra_payload}
+        if not stream:
+            response = await openai_async_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=response, payload=payload)
+            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        else:
+            response = await openai_async_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
+
+        # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
+        # We want to exclude it to be consistent with other invocation layers
+        if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
+            stop_words = kwargs_with_defaults["stop"]
+            for idx, _ in enumerate(assistant_response):
+                for stop_word in stop_words:
+                    assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
+
+        if moderation and await check_openai_async_policy_violation(input=assistant_response, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", assistant_response)
+            return []
+
+        return assistant_response
+
+    def invoke(self, *args, **kwargs):
+        """
+        Invokes a prompt on the model. Based on the model, it takes in a prompt (or either a prompt or a list of messages)
+        and returns a list of responses using a REST invocation.
+
+        :return: The responses are being returned.
+
+        Note: Only kwargs relevant to OpenAI are passed to OpenAI rest API. Others kwargs are ignored.
+        For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+
+        if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
+            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+            return []
+
+        if isinstance(prompt, str):
+            messages = [{"role": "user", "content": prompt}]
+        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
+            messages = prompt
+        else:
+            raise ValueError(
+                f"The prompt format is different than what the model expects. "
+                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
+                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
+            )
+        extra_payload = {"messages": messages}
+        payload = {**base_payload, **extra_payload}
+        if not stream:
+            response = openai_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=response, payload=payload)
+            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        else:
+            response = openai_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
+
+        # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
+        # We want to exclude it to be consistent with other invocation layers
+        if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
+            stop_words = kwargs_with_defaults["stop"]
+            for idx, _ in enumerate(assistant_response):
+                for stop_word in stop_words:
+                    assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
+
+        if moderation and check_openai_policy_violation(input=assistant_response, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", assistant_response)
+            return []
+
+        return assistant_response
@@ -0,0 +1,6 @@
+---
+fixes:
+  - |
+    Fixed the bug that prevented the correct usage of ChatGPT invocation layer
+    in 1.21.1.
+    Added async support for ChatGPT invocation layer.
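The content-moderation path touched by this commit is driven through PromptNode's model_kwargs. A minimal sketch of that configuration, mirroring the test_content_moderation_gpt_35 hunk further down (the api_key value is a placeholder):

from haystack.nodes.prompt import PromptNode

# With moderate_content=True the ChatGPT invocation layer checks both the prompt and the
# generated answer against the OpenAI moderation endpoint and returns [] if either is flagged.
prompt_node = PromptNode(
    model_name_or_path="gpt-3.5-turbo",
    api_key="OPENAI_API_KEY",  # placeholder
    model_kwargs={"moderate_content": True},
)
result = prompt_node("Tell me a short joke.")  # [] if the prompt or the answer is flagged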
@@ -1,13 +1,13 @@
+import logging
 from unittest.mock import patch

-import logging
 import pytest

 from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer


 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
+@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
 def test_default_api_base(mock_request):
     with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
         invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key")
@@ -19,7 +19,7 @@ def test_default_api_base(mock_request):


 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
+@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
 def test_custom_api_base(mock_request):
     with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
         invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
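The two test hunks above only swap the patch target from the open_ai module to the chatgpt module: since the fixed layer now calls openai_request imported into haystack.nodes.prompt.invocation_layer.chatgpt, the mock must patch the name where it is looked up. A self-contained sketch of that general unittest.mock rule, using toy module names invented purely for illustration:

import sys
import types
from unittest.mock import patch

# Toy stand-ins for "the module that defines a function" and "the module that imports it".
util = types.ModuleType("util_mod")
util.fetch = lambda: "real"
sys.modules["util_mod"] = util

consumer = types.ModuleType("consumer_mod")
consumer.fetch = util.fetch          # what "from util_mod import fetch" does: binds a local name
sys.modules["consumer_mod"] = consumer

with patch("consumer_mod.fetch", return_value="mocked"):
    print(sys.modules["consumer_mod"].fetch())   # "mocked" - patched where it is looked up
with patch("util_mod.fetch", return_value="mocked"):
    print(sys.modules["consumer_mod"].fetch())   # "real" - consumer still holds its own reference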
@@ -1,20 +1,20 @@
-import os
 import logging
-from typing import Optional, Union, List, Dict, Any, Tuple
-from unittest.mock import patch, Mock, MagicMock, AsyncMock
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
+from unittest.mock import AsyncMock, MagicMock, Mock, patch

 import pytest
 from prompthub import Prompt

-from haystack import Document, Pipeline, BaseComponent, MultiLabel
-from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
-from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
+from haystack import BaseComponent, Document, MultiLabel, Pipeline
+from haystack.nodes.prompt import PromptModel, PromptNode, PromptTemplate
 from haystack.nodes.prompt.invocation_layer import (
     AzureChatGPTInvocationLayer,
     AzureOpenAIInvocationLayer,
-    OpenAIInvocationLayer,
     ChatGPTInvocationLayer,
+    OpenAIInvocationLayer,
 )
+from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES


 @pytest.fixture
@@ -1082,8 +1082,8 @@ def test_content_moderation_gpt_35():
     ChatGPTInvocationLayer.
     """
     prompt_node = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True})
-    with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
-        "haystack.nodes.prompt.invocation_layer.open_ai.openai_request"
+    with patch("haystack.nodes.prompt.invocation_layer.chatgpt.check_openai_policy_violation") as mock_check, patch(
+        "haystack.nodes.prompt.invocation_layer.chatgpt.openai_request"
     ) as mock_request:
         VIOLENT_TEXT = "some violent text"
         mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
@@ -1093,11 +1093,15 @@
         assert prompt_node(VIOLENT_TEXT) == []
         # case 2: prompt passes the moderation check but the generated output fails the check
         # function should also return an empty list
-        mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]}
+        mock_request.return_value = {
+            "choices": [{"message": {"content": VIOLENT_TEXT, "role": "assistant"}, "finish_reason": ""}]
+        }
         assert prompt_node("normal prompt") == []
         # case 3: both prompt and output pass the moderation check
         # function should return the output
-        mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]}
+        mock_request.return_value = {
+            "choices": [{"message": {"content": "normal output", "role": "assistant"}, "finish_reason": ""}]
+        }
         assert prompt_node("normal prompt") == ["normal output"]
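The mocked responses above change shape because the invocation layer now parses Chat Completions payloads rather than Completions payloads. A minimal sketch of the extraction the layer performs, using sample data only:

# Chat-style payload, as mocked in the updated test:
response = {
    "choices": [
        {"message": {"content": "normal output", "role": "assistant"}, "finish_reason": ""}
    ]
}

# Extraction used by the (a)invoke code above:
assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
print(assistant_response)  # ["normal output"]

# The old completions-style mock ({"choices": [{"text": ...}]}) would raise KeyError here,
# which is why the test payloads were updated along with the fix.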