diff --git a/haystack/nodes/prompt/invocation_layer/open_ai.py b/haystack/nodes/prompt/invocation_layer/open_ai.py
index 759fd7a92..85bae198b 100644
--- a/haystack/nodes/prompt/invocation_layer/open_ai.py
+++ b/haystack/nodes/prompt/invocation_layer/open_ai.py
@@ -7,11 +7,13 @@ import sseclient
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
 from haystack.utils.openai_utils import (
-    openai_request,
     _openai_text_completion_tokenization_details,
     load_openai_tokenizer,
     _check_openai_finish_reason,
+    check_openai_async_policy_violation,
     check_openai_policy_violation,
+    openai_async_request,
+    openai_request,
 )
 from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
@@ -112,17 +114,13 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             headers["OpenAI-Organization"] = self.openai_organization
         return headers
 
-    def invoke(self, *args, **kwargs):
-        """
-        Invokes a prompt on the model. Based on the model, it takes in a prompt (or either a prompt or a list of messages)
-        and returns a list of responses using a REST invocation.
-
-        :return: The responses are being returned.
-
-        Note: Only kwargs relevant to OpenAI are passed to OpenAI rest API. Others kwargs are ignored.
-        For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
-        """
+    def _prepare_invoke(self, *args, **kwargs):
         prompt = kwargs.get("prompt")
+        if not prompt:
+            raise ValueError(
+                f"No prompt provided. Model {self.model_name_or_path} requires a prompt. "
+                f"Make sure to provide a prompt in kwargs."
+            )
         # either stream is True (will use default handler) or stream_handler is provided
         kwargs_with_defaults = self.model_input_kwargs
         if kwargs:
@@ -150,23 +148,25 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }
+
+        return (prompt, base_payload, kwargs_with_defaults, stream, moderation)
+
+    def invoke(self, *args, **kwargs):
+        """
+        Invokes a prompt on the model. Based on the model, it takes in either a prompt or a list of messages
+        and returns a list of responses using a REST invocation.
+
+        :return: The list of responses.
+
+        Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
+        For more details, see the OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+
         if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
             logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
             return []
-        responses = self._execute_openai_request(
-            prompt=prompt, base_payload=base_payload, kwargs_with_defaults=kwargs_with_defaults, stream=stream
-        )
-        if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
-            logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
-            return []
-        return responses
-
-    def _execute_openai_request(self, prompt: str, base_payload: Dict, kwargs_with_defaults: Dict, stream: bool):
-        if not prompt:
-            raise ValueError(
-                f"No prompt provided. Model {self.model_name_or_path} requires prompt."
-                f"Make sure to provide prompt in kwargs."
-            )
+
         extra_payload = {
             "prompt": prompt,
             "suffix": kwargs_with_defaults.get("suffix", None),
@@ -179,13 +179,52 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             res = openai_request(url=self.url, headers=self.headers, payload=payload)
             _check_openai_finish_reason(result=res, payload=payload)
             responses = [ans["text"].strip() for ans in res["choices"]]
-            return responses
         else:
             response = openai_request(
                 url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
             )
             handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
-            return self._process_streaming_response(response=response, stream_handler=handler)
+            responses = self._process_streaming_response(response=response, stream_handler=handler)
+
+        if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
+            return []
+
+        return responses
+
+    async def ainvoke(self, *args, **kwargs):
+        """
+        asyncio version of the `invoke` method.
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+        if moderation and await check_openai_async_policy_violation(input=prompt, headers=self.headers):
+            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+            return []
+
+        extra_payload = {
+            "prompt": prompt,
+            "suffix": kwargs_with_defaults.get("suffix", None),
+            "logprobs": kwargs_with_defaults.get("logprobs", None),
+            "echo": kwargs_with_defaults.get("echo", False),
+            "best_of": kwargs_with_defaults.get("best_of", 1),
+        }
+        payload = {**base_payload, **extra_payload}
+        if not stream:
+            res = await openai_async_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=res, payload=payload)
+            responses = [ans["text"].strip() for ans in res["choices"]]
+        else:
+            response = await openai_async_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            responses = self._process_streaming_response(response=response, stream_handler=handler)
+
+        if moderation and await check_openai_async_policy_violation(input=responses, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
+            return []
+
+        return responses
 
     def _process_streaming_response(self, response, stream_handler: TokenStreamingHandler):
         client = sseclient.SSEClient(response)
diff --git a/haystack/nodes/prompt/prompt_model.py b/haystack/nodes/prompt/prompt_model.py
index c4c903089..100a2b46e 100644
--- a/haystack/nodes/prompt/prompt_model.py
+++ b/haystack/nodes/prompt/prompt_model.py
@@ -115,11 +115,11 @@ class PromptModel(BaseComponent):
         """
         Drop-in replacement asyncio version of the `invoke` method, see there for documentation.
""" - try: - return await self.model_invocation_layer.invoke(prompt=prompt, **kwargs) - except TypeError: - # The `invoke` method of the underlying invocation layer doesn't support asyncio - return self.model_invocation_layer.invoke(prompt=prompt, **kwargs) + if hasattr(self.model_invocation_layer, "ainvoke"): + return await self.model_invocation_layer.ainvoke(prompt=prompt, **kwargs) + + # The underlying invocation layer doesn't support asyncio + return self.model_invocation_layer.invoke(prompt=prompt, **kwargs) @overload def _ensure_token_limit(self, prompt: str) -> str: diff --git a/haystack/utils/openai_utils.py b/haystack/utils/openai_utils.py index fdde7306c..439d04598 100644 --- a/haystack/utils/openai_utils.py +++ b/haystack/utils/openai_utils.py @@ -3,7 +3,9 @@ import os import logging import platform import json -from typing import Dict, Union, Tuple, Optional, List +from typing import Dict, Union, Tuple, Optional, List, cast + +import httpx import requests import tenacity import tiktoken @@ -143,6 +145,53 @@ def openai_request( return response +@tenacity.retry( + reraise=True, + retry=tenacity.retry_if_exception_type(OpenAIError) + and tenacity.retry_if_not_exception_type(OpenAIUnauthorizedError), + wait=tenacity.wait_exponential(multiplier=OPENAI_BACKOFF), + stop=tenacity.stop_after_attempt(OPENAI_MAX_RETRIES), +) +async def openai_async_request( + url: str, + headers: Dict, + payload: Dict, + timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT, + read_response: bool = True, + **kwargs, +): + """Make a request to the OpenAI API given a `url`, `headers`, `payload`, and `timeout`. + + See `openai_request`. + """ + async with httpx.AsyncClient() as client: + response = await client.request( + "POST", url, headers=headers, json=payload, timeout=cast(float, timeout), **kwargs + ) + + if read_response: + json_response = json.loads(response.text) + + if response.status_code != 200: + openai_error: OpenAIError + if response.status_code == 429: + openai_error = OpenAIRateLimitError(f"API rate limit exceeded: {response.text}") + elif response.status_code == 401: + openai_error = OpenAIUnauthorizedError(f"API key is invalid: {response.text}") + else: + openai_error = OpenAIError( + f"OpenAI returned an error.\n" + f"Status code: {response.status_code}\n" + f"Response body: {response.text}", + status_code=response.status_code, + ) + raise openai_error + if read_response: + return json_response + else: + return response + + def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) -> bool: """ Calls the moderation endpoint to check if the text(s) violate the policy. @@ -163,6 +212,26 @@ def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) - return flagged +async def check_openai_async_policy_violation(input: Union[List[str], str], headers: Dict) -> bool: + """ + Calls the moderation endpoint to check if the text(s) violate the policy. + See [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation) for more details. + Returns true if any of the input is flagged as any of ['sexual', 'hate', 'violence', 'self-harm', 'sexual/minors', 'hate/threatening', 'violence/graphic']. 
+ """ + response = await openai_async_request(url=OPENAI_MODERATION_URL, headers=headers, payload={"input": input}) + results = response["results"] + flagged = any(res["flagged"] for res in results) + if flagged: + for result in results: + if result["flagged"]: + logger.debug( + "OpenAI Moderation API flagged the text '%s' as a potential policy violation of the following categories: %s", + input, + result["categories"], + ) + return flagged + + def _check_openai_finish_reason(result: Dict, payload: Dict) -> None: """Check the `finish_reason` the answers returned by OpenAI completions endpoint. If the `finish_reason` is `length` or `content_filter`, log a warning to the user. diff --git a/pyproject.toml b/pyproject.toml index a721ee914..3c11d398c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ classifiers = [ ] dependencies = [ "requests", + "httpx", "pydantic<2", "transformers==4.32.1", "pandas", diff --git a/releasenotes/notes/add-openai-async-1f65701142f77181.yaml b/releasenotes/notes/add-openai-async-1f65701142f77181.yaml new file mode 100644 index 000000000..7d97d8064 --- /dev/null +++ b/releasenotes/notes/add-openai-async-1f65701142f77181.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Add asyncio support to the OpenAI invocation layer. diff --git a/test/prompt/invocation_layer/test_chatgpt.py b/test/prompt/invocation_layer/test_chatgpt.py index 903635390..8c6fb1b40 100644 --- a/test/prompt/invocation_layer/test_chatgpt.py +++ b/test/prompt/invocation_layer/test_chatgpt.py @@ -7,7 +7,7 @@ from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer @pytest.mark.unit -@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request") +@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request") def test_default_api_base(mock_request): with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"): invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key") @@ -19,7 +19,7 @@ def test_default_api_base(mock_request): @pytest.mark.unit -@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request") +@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request") def test_custom_api_base(mock_request): with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"): invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com") diff --git a/test/prompt/test_prompt_node.py b/test/prompt/test_prompt_node.py index c0ae8480e..6cd4bd1a2 100644 --- a/test/prompt/test_prompt_node.py +++ b/test/prompt/test_prompt_node.py @@ -1046,36 +1046,59 @@ def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model): @pytest.mark.unit @patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None) @patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt) -def test_content_moderation_gpt_3_and_gpt_3_5(): +def test_content_moderation_gpt_3(): """ - Check all possible cases of the moderation checks passing / failing in a PromptNode call - for both ChatGPTInvocationLayer and OpenAIInvocationLayer. + Check all possible cases of the moderation checks passing / failing in a PromptNode uses + OpenAIInvocationLayer. 
""" - prompt_node_gpt_3_5 = PromptNode( - model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True} - ) - prompt_node_gpt_3 = PromptNode( + prompt_node = PromptNode( model_name_or_path="text-davinci-003", api_key="key", model_kwargs={"moderate_content": True} ) with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch( - "haystack.nodes.prompt.invocation_layer.chatgpt.ChatGPTInvocationLayer._execute_openai_request" - ) as mock_execute_gpt_3_5, patch( - "haystack.nodes.prompt.invocation_layer.open_ai.OpenAIInvocationLayer._execute_openai_request" - ) as mock_execute_gpt_3: + "haystack.nodes.prompt.invocation_layer.open_ai.openai_request" + ) as mock_request: VIOLENT_TEXT = "some violent text" mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT] # case 1: prompt fails the moderation check # prompt should not be sent to OpenAi & function should return an empty list mock_check.return_value = True - assert prompt_node_gpt_3_5(VIOLENT_TEXT) == prompt_node_gpt_3(VIOLENT_TEXT) == [] + assert prompt_node(VIOLENT_TEXT) == [] # case 2: prompt passes the moderation check but the generated output fails the check # function should also return an empty list - mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = [VIOLENT_TEXT] - assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == [] + mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]} + assert prompt_node("normal prompt") == [] # case 3: both prompt and output pass the moderation check # function should return the output - mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = ["normal output"] - assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == ["normal output"] + mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]} + assert prompt_node("normal prompt") == ["normal output"] + + +@pytest.mark.unit +@patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None) +@patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt) +def test_content_moderation_gpt_35(): + """ + Check all possible cases of the moderation checks passing / failing in a PromptNode uses + ChatGPTInvocationLayer. 
+ """ + prompt_node = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True}) + with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch( + "haystack.nodes.prompt.invocation_layer.open_ai.openai_request" + ) as mock_request: + VIOLENT_TEXT = "some violent text" + mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT] + # case 1: prompt fails the moderation check + # prompt should not be sent to OpenAi & function should return an empty list + mock_check.return_value = True + assert prompt_node(VIOLENT_TEXT) == [] + # case 2: prompt passes the moderation check but the generated output fails the check + # function should also return an empty list + mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]} + assert prompt_node("normal prompt") == [] + # case 3: both prompt and output pass the moderation check + # function should return the output + mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]} + assert prompt_node("normal prompt") == ["normal output"] @patch("haystack.nodes.prompt.prompt_node.PromptModel")