Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-10-16 10:28:55 +00:00)
feat: add support for async openai calls (#5946)
* add support for async openai calls
* add actual async call
* split the async api
* ask permission
* Update haystack/utils/openai_utils.py
  Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
* Fix OpenAI content moderation tests
* Fix ChatGPT invocation layer tests

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
parent 1ccf674d73
commit ac408134f4
@@ -7,11 +7,13 @@ import sseclient
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
 from haystack.utils.openai_utils import (
-    openai_request,
     _openai_text_completion_tokenization_details,
     load_openai_tokenizer,
     _check_openai_finish_reason,
+    check_openai_async_policy_violation,
     check_openai_policy_violation,
+    openai_async_request,
+    openai_request,
 )
 from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
@@ -112,17 +114,13 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             headers["OpenAI-Organization"] = self.openai_organization
         return headers

-    def invoke(self, *args, **kwargs):
-        """
-        Invokes a prompt on the model. Based on the model, it takes in a prompt (or either a prompt or a list of messages)
-        and returns a list of responses using a REST invocation.
-
-        :return: The responses are being returned.
-
-        Note: Only kwargs relevant to OpenAI are passed to OpenAI rest API. Others kwargs are ignored.
-        For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
-        """
+    def _prepare_invoke(self, *args, **kwargs):
         prompt = kwargs.get("prompt")
         if not prompt:
             raise ValueError(
                 f"No prompt provided. Model {self.model_name_or_path} requires prompt."
                 f"Make sure to provide prompt in kwargs."
             )
         # either stream is True (will use default handler) or stream_handler is provided
         kwargs_with_defaults = self.model_input_kwargs
         if kwargs:
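
The hunk above is the core of the refactor: the old `invoke` body becomes `_prepare_invoke`, so argument validation and payload assembly can be shared by a synchronous and an asynchronous entry point. A minimal, self-contained sketch of that pattern (the names below are illustrative, not Haystack's API):

from typing import Any, Dict, Tuple


class Client:
    """Toy client showing how a sync and an async entry point share one prep step."""

    def _prepare(self, **kwargs) -> Tuple[str, Dict[str, Any]]:
        # Validation and payload assembly live in one place.
        prompt = kwargs.get("prompt")
        if not prompt:
            raise ValueError("No prompt provided.")
        return prompt, {"prompt": prompt, "max_tokens": kwargs.get("max_tokens", 16)}

    def invoke(self, **kwargs) -> Dict[str, Any]:
        _, payload = self._prepare(**kwargs)
        return payload  # a real client would make a blocking HTTP call here

    async def ainvoke(self, **kwargs) -> Dict[str, Any]:
        _, payload = self._prepare(**kwargs)
        return payload  # a real client would await a non-blocking HTTP call here
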
@@ -150,23 +148,25 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }

+        return (prompt, base_payload, kwargs_with_defaults, stream, moderation)

+    def invoke(self, *args, **kwargs):
+        """
+        Invokes a prompt on the model. Based on the model, it takes in a prompt (or either a prompt or a list of messages)
+        and returns a list of responses using a REST invocation.
+
+        :return: The responses are being returned.
+
+        Note: Only kwargs relevant to OpenAI are passed to OpenAI rest API. Others kwargs are ignored.
+        For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+
         if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
             logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
             return []
-        responses = self._execute_openai_request(
-            prompt=prompt, base_payload=base_payload, kwargs_with_defaults=kwargs_with_defaults, stream=stream
-        )
-        if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
-            logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
-            return []
-        return responses

-    def _execute_openai_request(self, prompt: str, base_payload: Dict, kwargs_with_defaults: Dict, stream: bool):
-        if not prompt:
-            raise ValueError(
-                f"No prompt provided. Model {self.model_name_or_path} requires prompt."
-                f"Make sure to provide prompt in kwargs."
-            )
         extra_payload = {
             "prompt": prompt,
             "suffix": kwargs_with_defaults.get("suffix", None),
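
In the rewritten `invoke` shown above, the former `_execute_openai_request` helper appears to be folded back into the method, with a moderation gate before the request and a second gate on the generated text. From the user's side this is still opted into via `model_kwargs={"moderate_content": True}`; a hedged usage sketch (model name and API key are placeholders):

from haystack.nodes import PromptNode

prompt_node = PromptNode(
    model_name_or_path="text-davinci-003",
    api_key="OPENAI_API_KEY",  # placeholder, use a real key
    model_kwargs={"moderate_content": True},
)
# Returns [] if either the prompt or the generated text is flagged by the moderation endpoint.
print(prompt_node("Tell me a short story about a lighthouse."))
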
@@ -179,13 +179,52 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             res = openai_request(url=self.url, headers=self.headers, payload=payload)
             _check_openai_finish_reason(result=res, payload=payload)
             responses = [ans["text"].strip() for ans in res["choices"]]
-            return responses
         else:
             response = openai_request(
                 url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
             )
             handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
-            return self._process_streaming_response(response=response, stream_handler=handler)
+            responses = self._process_streaming_response(response=response, stream_handler=handler)

+        if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
+            return []

+        return responses

+    async def ainvoke(self, *args, **kwargs):
+        """
+        asyncio version of the `invoke` method.
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+        if moderation and await check_openai_async_policy_violation(input=prompt, headers=self.headers):
+            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+            return []

+        extra_payload = {
+            "prompt": prompt,
+            "suffix": kwargs_with_defaults.get("suffix", None),
+            "logprobs": kwargs_with_defaults.get("logprobs", None),
+            "echo": kwargs_with_defaults.get("echo", False),
+            "best_of": kwargs_with_defaults.get("best_of", 1),
+        }
+        payload = {**base_payload, **extra_payload}
+        if not stream:
+            res = await openai_async_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=res, payload=payload)
+            responses = [ans["text"].strip() for ans in res["choices"]]
+        else:
+            response = await openai_async_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            responses = self._process_streaming_response(response=response, stream_handler=handler)

+        if moderation and await check_openai_async_policy_violation(input=responses, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
+            return []

+        return responses

     def _process_streaming_response(self, response, stream_handler: TokenStreamingHandler):
         client = sseclient.SSEClient(response)
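
A hedged usage sketch of the new `ainvoke` entry point, calling the invocation layer directly. It assumes `OpenAIInvocationLayer` is importable from the same package as `ChatGPTInvocationLayer` in the tests below, and the API key is a placeholder:

import asyncio

from haystack.nodes.prompt.invocation_layer import OpenAIInvocationLayer


async def main():
    layer = OpenAIInvocationLayer(api_key="OPENAI_API_KEY", model_name_or_path="text-davinci-003")
    # Multiple prompts can now be awaited concurrently instead of blocking one by one.
    results = await asyncio.gather(
        layer.ainvoke(prompt="Summarize Hamlet in one sentence."),
        layer.ainvoke(prompt="Summarize Macbeth in one sentence."),
    )
    print(results)


asyncio.run(main())
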
@@ -115,11 +115,11 @@ class PromptModel(BaseComponent):
         """
         Drop-in replacement asyncio version of the `invoke` method, see there for documentation.
         """
-        try:
-            return await self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
-        except TypeError:
-            # The `invoke` method of the underlying invocation layer doesn't support asyncio
-            return self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
+        if hasattr(self.model_invocation_layer, "ainvoke"):
+            return await self.model_invocation_layer.ainvoke(prompt=prompt, **kwargs)
+
+        # The underlying invocation layer doesn't support asyncio
+        return self.model_invocation_layer.invoke(prompt=prompt, **kwargs)

     @overload
     def _ensure_token_limit(self, prompt: str) -> str:
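
The PromptModel change above swaps a try/except TypeError dance for an explicit capability check: the old version could also swallow a TypeError raised inside the layer's own `invoke`, while `hasattr(..., "ainvoke")` only asks whether an async entry point exists. The same dispatch pattern in isolation (illustrative helper, not Haystack code):

async def call_model(layer, prompt: str, **kwargs):
    # Prefer the async API when the invocation layer provides one,
    # otherwise fall back to the blocking call.
    if hasattr(layer, "ainvoke"):
        return await layer.ainvoke(prompt=prompt, **kwargs)
    return layer.invoke(prompt=prompt, **kwargs)
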
@@ -3,7 +3,9 @@ import os
 import logging
 import platform
 import json
-from typing import Dict, Union, Tuple, Optional, List
+from typing import Dict, Union, Tuple, Optional, List, cast

+import httpx
 import requests
 import tenacity
 import tiktoken
@@ -143,6 +145,53 @@ def openai_request(
     return response


+@tenacity.retry(
+    reraise=True,
+    retry=tenacity.retry_if_exception_type(OpenAIError)
+    and tenacity.retry_if_not_exception_type(OpenAIUnauthorizedError),
+    wait=tenacity.wait_exponential(multiplier=OPENAI_BACKOFF),
+    stop=tenacity.stop_after_attempt(OPENAI_MAX_RETRIES),
+)
+async def openai_async_request(
+    url: str,
+    headers: Dict,
+    payload: Dict,
+    timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT,
+    read_response: bool = True,
+    **kwargs,
+):
+    """Make a request to the OpenAI API given a `url`, `headers`, `payload`, and `timeout`.
+
+    See `openai_request`.
+    """
+    async with httpx.AsyncClient() as client:
+        response = await client.request(
+            "POST", url, headers=headers, json=payload, timeout=cast(float, timeout), **kwargs
+        )
+
+    if read_response:
+        json_response = json.loads(response.text)
+
+    if response.status_code != 200:
+        openai_error: OpenAIError
+        if response.status_code == 429:
+            openai_error = OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
+        elif response.status_code == 401:
+            openai_error = OpenAIUnauthorizedError(f"API key is invalid: {response.text}")
+        else:
+            openai_error = OpenAIError(
+                f"OpenAI returned an error.\n"
+                f"Status code: {response.status_code}\n"
+                f"Response body: {response.text}",
+                status_code=response.status_code,
+            )
+        raise openai_error
+    if read_response:
+        return json_response
+    else:
+        return response


 def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
     """
     Calls the moderation endpoint to check if the text(s) violate the policy.
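
The new `openai_async_request` (added to haystack/utils/openai_utils.py, per the commit message) mirrors `openai_request` but issues the POST through `httpx.AsyncClient` while keeping a tenacity retry policy. A standalone sketch of that httpx-plus-tenacity combination, with a generic endpoint and error type standing in for the OpenAI specifics:

import httpx
import tenacity


class TransientError(Exception):
    """Stand-in for a retryable error such as a rate limit."""


@tenacity.retry(
    reraise=True,
    retry=tenacity.retry_if_exception_type(TransientError),
    wait=tenacity.wait_exponential(multiplier=1),
    stop=tenacity.stop_after_attempt(3),
)
async def post_json(url: str, payload: dict) -> dict:
    async with httpx.AsyncClient() as client:
        response = await client.post(url, json=payload, timeout=30.0)
    if response.status_code == 429:
        # Raising a retryable error re-enters the function with exponential backoff.
        raise TransientError(response.text)
    response.raise_for_status()
    return response.json()

# e.g. asyncio.run(post_json("https://httpbin.org/post", {"hello": "world"}))
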
@@ -163,6 +212,26 @@ def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
     return flagged


+async def check_openai_async_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
+    """
+    Calls the moderation endpoint to check if the text(s) violate the policy.
+    See [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation) for more details.
+    Returns true if any of the input is flagged as any of ['sexual', 'hate', 'violence', 'self-harm', 'sexual/minors', 'hate/threatening', 'violence/graphic'].
+    """
+    response = await openai_async_request(url=OPENAI_MODERATION_URL, headers=headers, payload={"input": input})
+    results = response["results"]
+    flagged = any(res["flagged"] for res in results)
+    if flagged:
+        for result in results:
+            if result["flagged"]:
+                logger.debug(
+                    "OpenAI Moderation API flagged the text '%s' as a potential policy violation of the following categories: %s",
+                    input,
+                    result["categories"],
+                )
+    return flagged


 def _check_openai_finish_reason(result: Dict, payload: Dict) -> None:
     """Check the `finish_reason` the answers returned by OpenAI completions endpoint.
     If the `finish_reason` is `length` or `content_filter`, log a warning to the user.
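
The async moderation helper above can be awaited directly; a hedged usage sketch (the Authorization header value is a placeholder):

import asyncio

from haystack.utils.openai_utils import check_openai_async_policy_violation


async def main():
    headers = {"Authorization": "Bearer OPENAI_API_KEY", "Content-Type": "application/json"}  # placeholder key
    flagged = await check_openai_async_policy_violation(input="text to screen", headers=headers)
    print("flagged:", flagged)


asyncio.run(main())
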
@@ -47,6 +47,7 @@ classifiers = [
 ]
 dependencies = [
     "requests",
+    "httpx",
     "pydantic<2",
     "transformers==4.32.1",
     "pandas",
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Add asyncio support to the OpenAI invocation layer.
@@ -7,7 +7,7 @@ from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer


 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
 def test_default_api_base(mock_request):
     with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
         invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key")
@@ -19,7 +19,7 @@ def test_default_api_base(mock_request):


 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
 def test_custom_api_base(mock_request):
     with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
         invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
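
The two test hunks above only move the patch target from the chatgpt module to the open_ai module, which suggests the ChatGPT layer now reaches OpenAI through code defined in open_ai.py. `unittest.mock.patch` must target the module where a name is looked up, not where it is defined; in miniature:

from unittest.mock import patch

# If module `a` does `from b import f`, calls inside `a` see the mock only when "a.f" is patched.
with patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request") as mock_request:
    mock_request.return_value = {"choices": [{"text": "stub", "finish_reason": ""}]}
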
@@ -1046,36 +1046,59 @@ def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model):
 @pytest.mark.unit
 @patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
 @patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt)
-def test_content_moderation_gpt_3_and_gpt_3_5():
+def test_content_moderation_gpt_3():
     """
-    Check all possible cases of the moderation checks passing / failing in a PromptNode call
-    for both ChatGPTInvocationLayer and OpenAIInvocationLayer.
+    Check all possible cases of the moderation checks passing / failing in a PromptNode uses
+    OpenAIInvocationLayer.
     """
-    prompt_node_gpt_3_5 = PromptNode(
-        model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True}
-    )
-    prompt_node_gpt_3 = PromptNode(
+    prompt_node = PromptNode(
         model_name_or_path="text-davinci-003", api_key="key", model_kwargs={"moderate_content": True}
     )
     with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
-        "haystack.nodes.prompt.invocation_layer.chatgpt.ChatGPTInvocationLayer._execute_openai_request"
-    ) as mock_execute_gpt_3_5, patch(
-        "haystack.nodes.prompt.invocation_layer.open_ai.OpenAIInvocationLayer._execute_openai_request"
-    ) as mock_execute_gpt_3:
+        "haystack.nodes.prompt.invocation_layer.open_ai.openai_request"
+    ) as mock_request:
         VIOLENT_TEXT = "some violent text"
         mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
         # case 1: prompt fails the moderation check
         # prompt should not be sent to OpenAi & function should return an empty list
         mock_check.return_value = True
-        assert prompt_node_gpt_3_5(VIOLENT_TEXT) == prompt_node_gpt_3(VIOLENT_TEXT) == []
+        assert prompt_node(VIOLENT_TEXT) == []
         # case 2: prompt passes the moderation check but the generated output fails the check
         # function should also return an empty list
-        mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = [VIOLENT_TEXT]
-        assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == []
+        mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]}
+        assert prompt_node("normal prompt") == []
         # case 3: both prompt and output pass the moderation check
         # function should return the output
-        mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = ["normal output"]
-        assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == ["normal output"]
+        mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]}
+        assert prompt_node("normal prompt") == ["normal output"]
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
+@patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt)
+def test_content_moderation_gpt_35():
+    """
+    Check all possible cases of the moderation checks passing / failing in a PromptNode uses
+    ChatGPTInvocationLayer.
+    """
+    prompt_node = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True})
+    with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
+        "haystack.nodes.prompt.invocation_layer.open_ai.openai_request"
+    ) as mock_request:
+        VIOLENT_TEXT = "some violent text"
+        mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
+        # case 1: prompt fails the moderation check
+        # prompt should not be sent to OpenAi & function should return an empty list
+        mock_check.return_value = True
+        assert prompt_node(VIOLENT_TEXT) == []
+        # case 2: prompt passes the moderation check but the generated output fails the check
+        # function should also return an empty list
+        mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]}
+        assert prompt_node("normal prompt") == []
+        # case 3: both prompt and output pass the moderation check
+        # function should return the output
+        mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]}
+        assert prompt_node("normal prompt") == ["normal output"]


 @patch("haystack.nodes.prompt.prompt_node.PromptModel")
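
For reference, the `mock_request.return_value` stubs in the rewritten tests above mimic the completions payload that the invocation layer parses; the extraction is the same list comprehension seen earlier in this diff:

res = {"choices": [{"text": " normal output ", "finish_reason": ""}]}
responses = [ans["text"].strip() for ans in res["choices"]]
assert responses == ["normal output"]
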