feat: add support for async openai calls (#5946)

* add support for async openai calls

* add actual async call

* split the async api

* ask permission

* Update haystack/utils/openai_utils.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Fix OpenAI content moderation tests

* Fix ChatGPT invocation layer tests

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
Massimiliano Pippi 2023-10-03 10:42:21 +02:00 committed by GitHub
parent 1ccf674d73
commit ac408134f4
7 changed files with 186 additions and 50 deletions

haystack/nodes/prompt/invocation_layer/open_ai.py

@@ -7,11 +7,13 @@ import sseclient
from haystack.errors import OpenAIError
from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
from haystack.utils.openai_utils import (
openai_request,
_openai_text_completion_tokenization_details,
load_openai_tokenizer,
_check_openai_finish_reason,
check_openai_async_policy_violation,
check_openai_policy_violation,
openai_async_request,
openai_request,
)
from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
@@ -112,17 +114,13 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
headers["OpenAI-Organization"] = self.openai_organization
return headers
def invoke(self, *args, **kwargs):
"""
Invokes a prompt on the model. Depending on the model, it takes in a prompt (or a list of messages)
and returns a list of responses using a REST invocation.
:return: The list of responses.
Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API; other kwargs are ignored.
For more details, see the OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
"""
def _prepare_invoke(self, *args, **kwargs):
prompt = kwargs.get("prompt")
if not prompt:
raise ValueError(
f"No prompt provided. Model {self.model_name_or_path} requires prompt."
f"Make sure to provide prompt in kwargs."
)
# either stream is True (will use default handler) or stream_handler is provided
kwargs_with_defaults = self.model_input_kwargs
if kwargs:
@@ -150,23 +148,25 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
"frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
"logit_bias": kwargs_with_defaults.get("logit_bias", {}),
}
return (prompt, base_payload, kwargs_with_defaults, stream, moderation)
def invoke(self, *args, **kwargs):
"""
Invokes a prompt on the model. Depending on the model, it takes in a prompt (or a list of messages)
and returns a list of responses using a REST invocation.
:return: The list of responses.
Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API; other kwargs are ignored.
For more details, see the OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
"""
prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
return []
responses = self._execute_openai_request(
prompt=prompt, base_payload=base_payload, kwargs_with_defaults=kwargs_with_defaults, stream=stream
)
if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
return []
return responses
def _execute_openai_request(self, prompt: str, base_payload: Dict, kwargs_with_defaults: Dict, stream: bool):
if not prompt:
raise ValueError(
f"No prompt provided. Model {self.model_name_or_path} requires prompt."
f"Make sure to provide prompt in kwargs."
)
extra_payload = {
"prompt": prompt,
"suffix": kwargs_with_defaults.get("suffix", None),
@@ -179,13 +179,52 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
res = openai_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_finish_reason(result=res, payload=payload)
responses = [ans["text"].strip() for ans in res["choices"]]
return responses
else:
response = openai_request(
url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
)
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
return self._process_streaming_response(response=response, stream_handler=handler)
responses = self._process_streaming_response(response=response, stream_handler=handler)
if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
return []
return responses
async def ainvoke(self, *args, **kwargs):
"""
asyncio version of the `invoke` method.
"""
prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
if moderation and await check_openai_async_policy_violation(input=prompt, headers=self.headers):
logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
return []
extra_payload = {
"prompt": prompt,
"suffix": kwargs_with_defaults.get("suffix", None),
"logprobs": kwargs_with_defaults.get("logprobs", None),
"echo": kwargs_with_defaults.get("echo", False),
"best_of": kwargs_with_defaults.get("best_of", 1),
}
payload = {**base_payload, **extra_payload}
if not stream:
res = await openai_async_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_finish_reason(result=res, payload=payload)
responses = [ans["text"].strip() for ans in res["choices"]]
else:
response = await openai_async_request(
url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
)
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
responses = self._process_streaming_response(response=response, stream_handler=handler)
if moderation and await check_openai_async_policy_violation(input=responses, headers=self.headers):
logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
return []
return responses
def _process_streaming_response(self, response, stream_handler: TokenStreamingHandler):
client = sseclient.SSEClient(response)

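Taken together, OpenAIInvocationLayer now exposes an `ainvoke` that mirrors `invoke` end to end, awaiting both the completion request and the moderation checks. A minimal usage sketch (the API key and kwargs are placeholders, not part of this commit):

import asyncio

from haystack.nodes.prompt.invocation_layer import OpenAIInvocationLayer


async def main():
    # "sk-..." is a placeholder API key, not a real credential
    layer = OpenAIInvocationLayer(api_key="sk-...", model_name_or_path="text-davinci-003")
    # ainvoke takes the same kwargs as invoke, but awaits the HTTP call
    responses = await layer.ainvoke(prompt="Say hello", max_tokens=16)
    print(responses)


asyncio.run(main())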
haystack/nodes/prompt/prompt_model.py

@@ -115,11 +115,11 @@ class PromptModel(BaseComponent):
"""
Drop-in asyncio replacement for the `invoke` method; see `invoke` for documentation.
"""
try:
return await self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
except TypeError:
# The `invoke` method of the underlying invocation layer doesn't support asyncio
return self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
if hasattr(self.model_invocation_layer, "ainvoke"):
return await self.model_invocation_layer.ainvoke(prompt=prompt, **kwargs)
# The underlying invocation layer doesn't support asyncio
return self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
@overload
def _ensure_token_limit(self, prompt: str) -> str:

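The new dispatch replaces the earlier await-and-catch-TypeError duck-typing with an explicit `hasattr` check. The same pattern in isolation, with hypothetical SyncLayer/AsyncLayer stand-ins for illustration:

import asyncio


class SyncLayer:
    def invoke(self, prompt, **kwargs):
        return [f"sync: {prompt}"]


class AsyncLayer(SyncLayer):
    async def ainvoke(self, prompt, **kwargs):
        return [f"async: {prompt}"]


async def dispatch(layer, prompt, **kwargs):
    # Prefer the native asyncio entry point when the layer provides one
    if hasattr(layer, "ainvoke"):
        return await layer.ainvoke(prompt=prompt, **kwargs)
    # Otherwise fall back to the synchronous implementation
    return layer.invoke(prompt=prompt, **kwargs)


print(asyncio.run(dispatch(AsyncLayer(), "hi")))  # ['async: hi']
print(asyncio.run(dispatch(SyncLayer(), "hi")))   # ['sync: hi']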
haystack/utils/openai_utils.py

@@ -3,7 +3,9 @@ import os
import logging
import platform
import json
from typing import Dict, Union, Tuple, Optional, List
from typing import Dict, Union, Tuple, Optional, List, cast
import httpx
import requests
import tenacity
import tiktoken
@@ -143,6 +145,53 @@ def openai_request(
return response
@tenacity.retry(
reraise=True,
retry=tenacity.retry_if_exception_type(OpenAIError)
& tenacity.retry_if_not_exception_type(OpenAIUnauthorizedError),  # `&` combines both predicates; plain `and` would keep only the right-hand one
wait=tenacity.wait_exponential(multiplier=OPENAI_BACKOFF),
stop=tenacity.stop_after_attempt(OPENAI_MAX_RETRIES),
)
async def openai_async_request(
url: str,
headers: Dict,
payload: Dict,
timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT,
read_response: bool = True,
**kwargs,
):
"""Make a request to the OpenAI API given a `url`, `headers`, `payload`, and `timeout`.
See `openai_request`.
"""
async with httpx.AsyncClient() as client:
response = await client.request(
"POST", url, headers=headers, json=payload, timeout=cast(float, timeout), **kwargs
)
if read_response:
json_response = json.loads(response.text)
if response.status_code != 200:
openai_error: OpenAIError
if response.status_code == 429:
openai_error = OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
elif response.status_code == 401:
openai_error = OpenAIUnauthorizedError(f"API key is invalid: {response.text}")
else:
openai_error = OpenAIError(
f"OpenAI returned an error.\n"
f"Status code: {response.status_code}\n"
f"Response body: {response.text}",
status_code=response.status_code,
)
raise openai_error
if read_response:
return json_response
else:
return response
def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
"""
Calls the moderation endpoint to check if the text(s) violate the policy.
@@ -163,6 +212,26 @@ def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
return flagged
async def check_openai_async_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
"""
Calls the moderation endpoint to check if the text(s) violate the policy.
See [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation) for more details.
Returns True if any of the inputs is flagged as one of ['sexual', 'hate', 'violence', 'self-harm', 'sexual/minors', 'hate/threatening', 'violence/graphic'].
"""
response = await openai_async_request(url=OPENAI_MODERATION_URL, headers=headers, payload={"input": input})
results = response["results"]
flagged = any(res["flagged"] for res in results)
if flagged:
for result in results:
if result["flagged"]:
logger.debug(
"OpenAI Moderation API flagged the text '%s' as a potential policy violation of the following categories: %s",
input,
result["categories"],
)
return flagged
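Similarly, a sketch of the async moderation check on its own (same placeholder headers as above):

import asyncio

from haystack.utils.openai_utils import check_openai_async_policy_violation


async def main():
    headers = {"Authorization": "Bearer sk-...", "Content-Type": "application/json"}  # placeholder key
    flagged = await check_openai_async_policy_violation(input="some violent text", headers=headers)
    print(flagged)  # True if the Moderation API flags the text


asyncio.run(main())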
def _check_openai_finish_reason(result: Dict, payload: Dict) -> None:
"""Check the `finish_reason` the answers returned by OpenAI completions endpoint.
If the `finish_reason` is `length` or `content_filter`, log a warning to the user.

pyproject.toml

@@ -47,6 +47,7 @@ classifiers = [
]
dependencies = [
"requests",
"httpx",
"pydantic<2",
"transformers==4.32.1",
"pandas",


@@ -0,0 +1,4 @@
---
enhancements:
- |
Add asyncio support to the OpenAI invocation layer.


@@ -7,7 +7,7 @@ from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
def test_default_api_base(mock_request):
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key")
@@ -19,7 +19,7 @@ def test_default_api_base(mock_request):
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
def test_custom_api_base(mock_request):
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")


@@ -1046,36 +1046,59 @@ def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model):
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
@patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt)
def test_content_moderation_gpt_3_and_gpt_3_5():
def test_content_moderation_gpt_3():
"""
Check all possible cases of the moderation checks passing / failing in a PromptNode call
for both ChatGPTInvocationLayer and OpenAIInvocationLayer.
Check all possible cases of the moderation checks passing / failing in a PromptNode that uses
OpenAIInvocationLayer.
"""
prompt_node_gpt_3_5 = PromptNode(
model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True}
)
prompt_node_gpt_3 = PromptNode(
prompt_node = PromptNode(
model_name_or_path="text-davinci-003", api_key="key", model_kwargs={"moderate_content": True}
)
with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
"haystack.nodes.prompt.invocation_layer.chatgpt.ChatGPTInvocationLayer._execute_openai_request"
) as mock_execute_gpt_3_5, patch(
"haystack.nodes.prompt.invocation_layer.open_ai.OpenAIInvocationLayer._execute_openai_request"
) as mock_execute_gpt_3:
"haystack.nodes.prompt.invocation_layer.open_ai.openai_request"
) as mock_request:
VIOLENT_TEXT = "some violent text"
mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
# case 1: prompt fails the moderation check
# prompt should not be sent to OpenAI & the function should return an empty list
mock_check.return_value = True
assert prompt_node_gpt_3_5(VIOLENT_TEXT) == prompt_node_gpt_3(VIOLENT_TEXT) == []
assert prompt_node(VIOLENT_TEXT) == []
# case 2: prompt passes the moderation check but the generated output fails the check
# function should also return an empty list
mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = [VIOLENT_TEXT]
assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == []
mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]}
assert prompt_node("normal prompt") == []
# case 3: both prompt and output pass the moderation check
# function should return the output
mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = ["normal output"]
assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == ["normal output"]
mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]}
assert prompt_node("normal prompt") == ["normal output"]
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
@patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt)
def test_content_moderation_gpt_35():
"""
Check all possible cases of the moderation checks passing / failing in a PromptNode that uses
ChatGPTInvocationLayer.
"""
prompt_node = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True})
with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
"haystack.nodes.prompt.invocation_layer.open_ai.openai_request"
) as mock_request:
VIOLENT_TEXT = "some violent text"
mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
# case 1: prompt fails the moderation check
# prompt should not be sent to OpenAI & the function should return an empty list
mock_check.return_value = True
assert prompt_node(VIOLENT_TEXT) == []
# case 2: prompt passes the moderation check but the generated output fails the check
# function should also return an empty list
mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]}
assert prompt_node("normal prompt") == []
# case 3: both prompt and output pass the moderation check
# function should return the output
mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]}
assert prompt_node("normal prompt") == ["normal output"]
@patch("haystack.nodes.prompt.prompt_node.PromptModel")