diff --git a/haystack/nodes/prompt/invocation_layer/chatgpt.py b/haystack/nodes/prompt/invocation_layer/chatgpt.py
index 371b86d6a..f3e1a3ef6 100644
--- a/haystack/nodes/prompt/invocation_layer/chatgpt.py
+++ b/haystack/nodes/prompt/invocation_layer/chatgpt.py
@@ -1,10 +1,17 @@
 import logging
-from typing import Optional, List, Dict, Union, Any
+from typing import Any, Dict, List, Optional, Union
 
 from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
 from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
-from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
+from haystack.utils.openai_utils import (
+    _check_openai_finish_reason,
+    check_openai_async_policy_violation,
+    check_openai_policy_violation,
+    count_openai_tokens_messages,
+    openai_async_request,
+    openai_request,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -43,45 +50,6 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
         """
         super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)
 
-    def _execute_openai_request(
-        self, prompt: Union[str, List[Dict]], base_payload: Dict, kwargs_with_defaults: Dict, stream: bool
-    ):
-        """
-        For more details, see [OpenAI ChatGPT API reference](https://platform.openai.com/docs/api-reference/chat).
-        """
-        if isinstance(prompt, str):
-            messages = [{"role": "user", "content": prompt}]
-        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
-            messages = prompt
-        else:
-            raise ValueError(
-                f"The prompt format is different than what the model expects. "
-                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
-                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
-            )
-        extra_payload = {"messages": messages}
-        payload = {**base_payload, **extra_payload}
-        if not stream:
-            response = openai_request(url=self.url, headers=self.headers, payload=payload)
-            _check_openai_finish_reason(result=response, payload=payload)
-            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
-        else:
-            response = openai_request(
-                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
-            )
-            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
-            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
-
-        # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
-        # We want to exclude it to be consistent with other invocation layers
-        if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
-            stop_words = kwargs_with_defaults["stop"]
-            for idx, _ in enumerate(assistant_response):
-                for stop_word in stop_words:
-                    assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
-
-        return assistant_response
-
     def _extract_token(self, event_data: Dict[str, Any]):
         delta = event_data["choices"][0]["delta"]
         if "content" in delta:
@@ -141,3 +109,109 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
             and not "gpt-3.5-turbo-instruct" in model_name_or_path
         )
         return valid_model and not has_azure_parameters(**kwargs)
+
+    async def ainvoke(self, *args, **kwargs):
+        """
+        Invokes a prompt on the model. It takes either a prompt string or a list of messages in ChatML format
+        and returns a list of responses using a REST invocation.
+
+        :return: A list of generated responses.
+
+        Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
+        For more details, see the OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+
+        if moderation and await check_openai_async_policy_violation(input=prompt, headers=self.headers):
+            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+            return []
+
+        if isinstance(prompt, str):
+            messages = [{"role": "user", "content": prompt}]
+        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
+            messages = prompt
+        else:
+            raise ValueError(
+                f"The prompt format is different than what the model expects. "
+                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
+                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
+            )
+        extra_payload = {"messages": messages}
+        payload = {**base_payload, **extra_payload}
+        if not stream:
+            response = await openai_async_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=response, payload=payload)
+            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        else:
+            response = await openai_async_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
+
+        # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
+        # We want to exclude it to be consistent with other invocation layers
+        if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
+            stop_words = kwargs_with_defaults["stop"]
+            for idx, _ in enumerate(assistant_response):
+                for stop_word in stop_words:
+                    assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
+
+        if moderation and await check_openai_async_policy_violation(input=assistant_response, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", assistant_response)
+            return []
+
+        return assistant_response
+
+    def invoke(self, *args, **kwargs):
+        """
+        Invokes a prompt on the model. It takes either a prompt string or a list of messages in ChatML format
+        and returns a list of responses using a REST invocation.
+
+        :return: A list of generated responses.
+
+        Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
+        For more details, see the OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+
+        if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
+            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+            return []
+
+        if isinstance(prompt, str):
+            messages = [{"role": "user", "content": prompt}]
+        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
+            messages = prompt
+        else:
+            raise ValueError(
+                f"The prompt format is different than what the model expects. "
+                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
+                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
+            )
+        extra_payload = {"messages": messages}
+        payload = {**base_payload, **extra_payload}
+        if not stream:
+            response = openai_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=response, payload=payload)
+            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        else:
+            response = openai_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
+
+        # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
+        # We want to exclude it to be consistent with other invocation layers
+        if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
+            stop_words = kwargs_with_defaults["stop"]
+            for idx, _ in enumerate(assistant_response):
+                for stop_word in stop_words:
+                    assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
+
+        if moderation and check_openai_policy_violation(input=assistant_response, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", assistant_response)
+            return []
+
+        return assistant_response
diff --git a/releasenotes/notes/fix-chatgpt-invocation-layer-bc25d0ea5f77f05c.yaml b/releasenotes/notes/fix-chatgpt-invocation-layer-bc25d0ea5f77f05c.yaml
new file mode 100644
index 000000000..2be3084dd
--- /dev/null
+++ b/releasenotes/notes/fix-chatgpt-invocation-layer-bc25d0ea5f77f05c.yaml
@@ -0,0 +1,6 @@
+---
+fixes:
+  - |
+    Fixed the bug that prevented the correct usage of the ChatGPT invocation layer
+    in 1.21.1.
+    Added async support for the ChatGPT invocation layer.
diff --git a/test/prompt/invocation_layer/test_chatgpt.py b/test/prompt/invocation_layer/test_chatgpt.py
index 8c6fb1b40..c1b816d49 100644
--- a/test/prompt/invocation_layer/test_chatgpt.py
+++ b/test/prompt/invocation_layer/test_chatgpt.py
@@ -1,13 +1,13 @@
+import logging
 from unittest.mock import patch
 
-import logging
 import pytest
 
 from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer
 
 
 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
+@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
 def test_default_api_base(mock_request):
     with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
         invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key")
@@ -19,7 +19,7 @@ def test_default_api_base(mock_request):
 
 
 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
+@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
 def test_custom_api_base(mock_request):
     with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
         invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
diff --git a/test/prompt/test_prompt_node.py b/test/prompt/test_prompt_node.py
index 6cd4bd1a2..972a04be1 100644
--- a/test/prompt/test_prompt_node.py
+++ b/test/prompt/test_prompt_node.py
@@ -1,20 +1,20 @@
-import os
 import logging
-from typing import Optional, Union, List, Dict, Any, Tuple
-from unittest.mock import patch, Mock, MagicMock, AsyncMock
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
 
 import pytest
 from prompthub import Prompt
 
-from haystack import Document, Pipeline, BaseComponent, MultiLabel
-from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
-from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
+from haystack import BaseComponent, Document, MultiLabel, Pipeline
+from haystack.nodes.prompt import PromptModel, PromptNode, PromptTemplate
 from haystack.nodes.prompt.invocation_layer import (
     AzureChatGPTInvocationLayer,
     AzureOpenAIInvocationLayer,
-    OpenAIInvocationLayer,
     ChatGPTInvocationLayer,
+    OpenAIInvocationLayer,
 )
+from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
 
 
 @pytest.fixture
@@ -1082,8 +1082,8 @@ def test_content_moderation_gpt_35():
     ChatGPTInvocationLayer.
     """
     prompt_node = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True})
-    with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
-        "haystack.nodes.prompt.invocation_layer.open_ai.openai_request"
+    with patch("haystack.nodes.prompt.invocation_layer.chatgpt.check_openai_policy_violation") as mock_check, patch(
+        "haystack.nodes.prompt.invocation_layer.chatgpt.openai_request"
     ) as mock_request:
         VIOLENT_TEXT = "some violent text"
         mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
@@ -1093,11 +1093,15 @@ def test_content_moderation_gpt_35():
         assert prompt_node(VIOLENT_TEXT) == []
         # case 2: prompt passes the moderation check but the generated output fails the check
         # function should also return an empty list
-        mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]}
+        mock_request.return_value = {
+            "choices": [{"message": {"content": VIOLENT_TEXT, "role": "assistant"}, "finish_reason": ""}]
+        }
        assert prompt_node("normal prompt") == []
         # case 3: both prompt and output pass the moderation check
         # function should return the output
-        mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]}
+        mock_request.return_value = {
+            "choices": [{"message": {"content": "normal output", "role": "assistant"}, "finish_reason": ""}]
+        }
         assert prompt_node("normal prompt") == ["normal output"]