Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-10-17 19:09:09 +00:00)
feat: add support for async openai calls (#5946)
* add support for async openai calls
* add actual async call
* split the async api
* ask permission
* Update haystack/utils/openai_utils.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Fix OpenAI content moderation tests
* Fix ChatGPT invocation layer tests

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
This commit is contained in: parent 1ccf674d73, commit ac408134f4
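In short, the commit adds an `ainvoke` coroutine to `OpenAIInvocationLayer` (and, via inheritance, to `ChatGPTInvocationLayer`), backed by a new `openai_async_request` helper built on httpx. A minimal usage sketch of what this enables (not part of the diff; the API key, model, and prompt are placeholders):

import asyncio

from haystack.nodes.prompt.invocation_layer import OpenAIInvocationLayer


async def main():
    # ainvoke mirrors invoke(): same kwargs, same list-of-strings result
    layer = OpenAIInvocationLayer(api_key="sk-...", model_name_or_path="text-davinci-003")
    responses = await layer.ainvoke(prompt="Say hello in one word.")
    print(responses)


asyncio.run(main())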
haystack/nodes/prompt/invocation_layer/open_ai.py
@@ -7,11 +7,13 @@ import sseclient
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
 from haystack.utils.openai_utils import (
-    openai_request,
     _openai_text_completion_tokenization_details,
     load_openai_tokenizer,
     _check_openai_finish_reason,
+    check_openai_async_policy_violation,
     check_openai_policy_violation,
+    openai_async_request,
+    openai_request,
 )
 from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
@@ -112,17 +114,13 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             headers["OpenAI-Organization"] = self.openai_organization
         return headers

-    def invoke(self, *args, **kwargs):
-        """
-        Invokes a prompt on the model. Based on the model, it takes in a prompt (or either a prompt or a list of messages)
-        and returns a list of responses using a REST invocation.
-
-        :return: The responses are being returned.
-
-        Note: Only kwargs relevant to OpenAI are passed to OpenAI rest API. Others kwargs are ignored.
-        For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
-        """
+    def _prepare_invoke(self, *args, **kwargs):
         prompt = kwargs.get("prompt")
+        if not prompt:
+            raise ValueError(
+                f"No prompt provided. Model {self.model_name_or_path} requires prompt."
+                f"Make sure to provide prompt in kwargs."
+            )
         # either stream is True (will use default handler) or stream_handler is provided
         kwargs_with_defaults = self.model_input_kwargs
         if kwargs:
@@ -150,23 +148,25 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }
+
+        return (prompt, base_payload, kwargs_with_defaults, stream, moderation)
+
+    def invoke(self, *args, **kwargs):
+        """
+        Invokes a prompt on the model. Based on the model, it takes in a prompt (or either a prompt or a list of messages)
+        and returns a list of responses using a REST invocation.
+
+        :return: The responses are being returned.
+
+        Note: Only kwargs relevant to OpenAI are passed to OpenAI rest API. Others kwargs are ignored.
+        For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+
         if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
             logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
             return []
-        responses = self._execute_openai_request(
-            prompt=prompt, base_payload=base_payload, kwargs_with_defaults=kwargs_with_defaults, stream=stream
-        )
-        if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
-            logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
-            return []
-        return responses
-
-    def _execute_openai_request(self, prompt: str, base_payload: Dict, kwargs_with_defaults: Dict, stream: bool):
-        if not prompt:
-            raise ValueError(
-                f"No prompt provided. Model {self.model_name_or_path} requires prompt."
-                f"Make sure to provide prompt in kwargs."
-            )
         extra_payload = {
             "prompt": prompt,
             "suffix": kwargs_with_defaults.get("suffix", None),
@@ -179,13 +179,52 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             res = openai_request(url=self.url, headers=self.headers, payload=payload)
             _check_openai_finish_reason(result=res, payload=payload)
             responses = [ans["text"].strip() for ans in res["choices"]]
-            return responses
         else:
             response = openai_request(
                 url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
             )
             handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
-            return self._process_streaming_response(response=response, stream_handler=handler)
+            responses = self._process_streaming_response(response=response, stream_handler=handler)
+
+        if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
+            return []
+
+        return responses
+
+    async def ainvoke(self, *args, **kwargs):
+        """
+        asyncio version of the `invoke` method.
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+        if moderation and await check_openai_async_policy_violation(input=prompt, headers=self.headers):
+            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+            return []
+
+        extra_payload = {
+            "prompt": prompt,
+            "suffix": kwargs_with_defaults.get("suffix", None),
+            "logprobs": kwargs_with_defaults.get("logprobs", None),
+            "echo": kwargs_with_defaults.get("echo", False),
+            "best_of": kwargs_with_defaults.get("best_of", 1),
+        }
+        payload = {**base_payload, **extra_payload}
+        if not stream:
+            res = await openai_async_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=res, payload=payload)
+            responses = [ans["text"].strip() for ans in res["choices"]]
+        else:
+            response = await openai_async_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            responses = self._process_streaming_response(response=response, stream_handler=handler)
+
+        if moderation and await check_openai_async_policy_violation(input=responses, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
+            return []
+
+        return responses

     def _process_streaming_response(self, response, stream_handler: TokenStreamingHandler):
         client = sseclient.SSEClient(response)
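`ainvoke` mirrors `invoke` almost line for line; the differences are the awaited transport (`openai_async_request` instead of `openai_request`) and the async moderation checks. Streaming goes through the same shared `_process_streaming_response`. A sketch of the streaming call shape, assuming `ainvoke` accepts the same `stream` and `stream_handler` kwargs that `invoke` does (key and prompt are placeholders):

import asyncio

from haystack.nodes.prompt.invocation_layer import OpenAIInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler


async def main():
    layer = OpenAIInvocationLayer(api_key="sk-...")
    responses = await layer.ainvoke(
        prompt="Count to five.",
        stream=True,  # tokens are pushed to the handler as they arrive
        stream_handler=DefaultTokenStreamingHandler(),  # prints tokens to stdout
    )
    print(responses)


asyncio.run(main())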
haystack/nodes/prompt/prompt_model.py
@@ -115,11 +115,11 @@ class PromptModel(BaseComponent):
         """
         Drop-in replacement asyncio version of the `invoke` method, see there for documentation.
         """
-        try:
-            return await self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
-        except TypeError:
-            # The `invoke` method of the underlying invocation layer doesn't support asyncio
-            return self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
+        if hasattr(self.model_invocation_layer, "ainvoke"):
+            return await self.model_invocation_layer.ainvoke(prompt=prompt, **kwargs)
+
+        # The underlying invocation layer doesn't support asyncio
+        return self.model_invocation_layer.invoke(prompt=prompt, **kwargs)

     @overload
     def _ensure_token_limit(self, prompt: str) -> str:
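The `try`/`except TypeError` probe is replaced with a capability check. The old form had a subtle cost: `await layer.invoke(...)` runs the synchronous `invoke` (including the HTTP request) before the `await` on its list result raises `TypeError`, so the fallback could issue the request twice. A self-contained sketch of the new dispatch pattern (the two layer classes are illustrative, not from the codebase):

import asyncio


class SyncOnlyLayer:
    def invoke(self, prompt, **kwargs):
        return [f"sync: {prompt}"]


class AsyncCapableLayer(SyncOnlyLayer):
    async def ainvoke(self, prompt, **kwargs):
        return [f"async: {prompt}"]


async def dispatch(layer, prompt):
    # capability check, same shape as PromptModel.ainvoke
    if hasattr(layer, "ainvoke"):
        return await layer.ainvoke(prompt=prompt)
    # the sync fallback runs inline and still blocks the event loop
    return layer.invoke(prompt=prompt)


print(asyncio.run(dispatch(SyncOnlyLayer(), "hi")))      # ['sync: hi']
print(asyncio.run(dispatch(AsyncCapableLayer(), "hi")))  # ['async: hi']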
haystack/utils/openai_utils.py
@@ -3,7 +3,9 @@ import os
 import logging
 import platform
 import json
-from typing import Dict, Union, Tuple, Optional, List
+from typing import Dict, Union, Tuple, Optional, List, cast
+
+import httpx
 import requests
 import tenacity
 import tiktoken
@@ -143,6 +145,53 @@ def openai_request(
     return response


+@tenacity.retry(
+    reraise=True,
+    retry=tenacity.retry_if_exception_type(OpenAIError)
+    and tenacity.retry_if_not_exception_type(OpenAIUnauthorizedError),
+    wait=tenacity.wait_exponential(multiplier=OPENAI_BACKOFF),
+    stop=tenacity.stop_after_attempt(OPENAI_MAX_RETRIES),
+)
+async def openai_async_request(
+    url: str,
+    headers: Dict,
+    payload: Dict,
+    timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT,
+    read_response: bool = True,
+    **kwargs,
+):
+    """Make a request to the OpenAI API given a `url`, `headers`, `payload`, and `timeout`.
+
+    See `openai_request`.
+    """
+    async with httpx.AsyncClient() as client:
+        response = await client.request(
+            "POST", url, headers=headers, json=payload, timeout=cast(float, timeout), **kwargs
+        )
+
+    if read_response:
+        json_response = json.loads(response.text)
+
+    if response.status_code != 200:
+        openai_error: OpenAIError
+        if response.status_code == 429:
+            openai_error = OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
+        elif response.status_code == 401:
+            openai_error = OpenAIUnauthorizedError(f"API key is invalid: {response.text}")
+        else:
+            openai_error = OpenAIError(
+                f"OpenAI returned an error.\n"
+                f"Status code: {response.status_code}\n"
+                f"Response body: {response.text}",
+                status_code=response.status_code,
+            )
+        raise openai_error
+    if read_response:
+        return json_response
+    else:
+        return response
+
+
 def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
     """
     Calls the moderation endpoint to check if the text(s) violate the policy.
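`openai_async_request` keeps the same tenacity retry policy as the synchronous `openai_request` but sends the request through `httpx.AsyncClient`, so retries and I/O no longer block the event loop. A sketch of calling the helper directly (the URL, headers, and payload are illustrative; normal callers go through `ainvoke`):

import asyncio

from haystack.utils.openai_utils import openai_async_request


async def main():
    result = await openai_async_request(
        url="https://api.openai.com/v1/completions",
        headers={"Content-Type": "application/json", "Authorization": "Bearer sk-..."},
        payload={"model": "text-davinci-003", "prompt": "Hello", "max_tokens": 5},
    )
    # with read_response=True (the default) the parsed JSON body is returned
    print(result["choices"][0]["text"])


asyncio.run(main())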
@@ -163,6 +212,26 @@ def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) -
     return flagged


+async def check_openai_async_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
+    """
+    Calls the moderation endpoint to check if the text(s) violate the policy.
+    See [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation) for more details.
+    Returns true if any of the input is flagged as any of ['sexual', 'hate', 'violence', 'self-harm', 'sexual/minors', 'hate/threatening', 'violence/graphic'].
+    """
+    response = await openai_async_request(url=OPENAI_MODERATION_URL, headers=headers, payload={"input": input})
+    results = response["results"]
+    flagged = any(res["flagged"] for res in results)
+    if flagged:
+        for result in results:
+            if result["flagged"]:
+                logger.debug(
+                    "OpenAI Moderation API flagged the text '%s' as a potential policy violation of the following categories: %s",
+                    input,
+                    result["categories"],
+                )
+    return flagged
+
+
 def _check_openai_finish_reason(result: Dict, payload: Dict) -> None:
     """Check the `finish_reason` the answers returned by OpenAI completions endpoint.
     If the `finish_reason` is `length` or `content_filter`, log a warning to the user.
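`check_openai_async_policy_violation` is a coroutine twin of `check_openai_policy_violation`: same moderation endpoint, same flagging logic, with `openai_async_request` as the transport. A sketch of the pre-flight check `ainvoke` performs when `moderate_content=True` (the key is a placeholder):

import asyncio

from haystack.utils.openai_utils import check_openai_async_policy_violation

headers = {"Content-Type": "application/json", "Authorization": "Bearer sk-..."}


async def main():
    flagged = await check_openai_async_policy_violation(input="some prompt", headers=headers)
    print("blocked" if flagged else "ok to send")


asyncio.run(main())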
pyproject.toml
@@ -47,6 +47,7 @@ classifiers = [
 ]
 dependencies = [
     "requests",
+    "httpx",
     "pydantic<2",
     "transformers==4.32.1",
     "pandas",
releasenotes/notes/… (new file)
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Add asyncio support to the OpenAI invocation layer.
test/prompt/invocation_layer/test_chatgpt.py
@@ -7,7 +7,7 @@ from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer


 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
 def test_default_api_base(mock_request):
     with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
         invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key")
@@ -19,7 +19,7 @@ def test_default_api_base(mock_request):


 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
 def test_custom_api_base(mock_request):
     with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
         invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
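The patch target moves from `chatgpt.openai_request` to `open_ai.openai_request`, presumably because after this refactor the ChatGPT layer reaches the request helper through the shared code in open_ai.py. `unittest.mock.patch` must target the module where a name is looked up, not where it is defined. A minimal illustration:

from unittest.mock import patch

target = "haystack.nodes.prompt.invocation_layer.open_ai.openai_request"
with patch(target) as mock_request:
    # any code that reaches openai_request through the open_ai module gets the
    # stub; patching chatgpt.openai_request would leave this lookup untouched
    mock_request.return_value = {"choices": [{"text": "stubbed", "finish_reason": ""}]}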
test/prompt/test_prompt_node.py
@@ -1046,36 +1046,59 @@ def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model):
 @pytest.mark.unit
 @patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
 @patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt)
-def test_content_moderation_gpt_3_and_gpt_3_5():
+def test_content_moderation_gpt_3():
     """
-    Check all possible cases of the moderation checks passing / failing in a PromptNode call
-    for both ChatGPTInvocationLayer and OpenAIInvocationLayer.
+    Check all possible cases of the moderation checks passing / failing in a PromptNode that uses
+    OpenAIInvocationLayer.
     """
-    prompt_node_gpt_3_5 = PromptNode(
-        model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True}
-    )
-    prompt_node_gpt_3 = PromptNode(
+    prompt_node = PromptNode(
         model_name_or_path="text-davinci-003", api_key="key", model_kwargs={"moderate_content": True}
     )
     with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
-        "haystack.nodes.prompt.invocation_layer.chatgpt.ChatGPTInvocationLayer._execute_openai_request"
-    ) as mock_execute_gpt_3_5, patch(
-        "haystack.nodes.prompt.invocation_layer.open_ai.OpenAIInvocationLayer._execute_openai_request"
-    ) as mock_execute_gpt_3:
+        "haystack.nodes.prompt.invocation_layer.open_ai.openai_request"
+    ) as mock_request:
         VIOLENT_TEXT = "some violent text"
         mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
         # case 1: prompt fails the moderation check
         # prompt should not be sent to OpenAi & function should return an empty list
         mock_check.return_value = True
-        assert prompt_node_gpt_3_5(VIOLENT_TEXT) == prompt_node_gpt_3(VIOLENT_TEXT) == []
+        assert prompt_node(VIOLENT_TEXT) == []
         # case 2: prompt passes the moderation check but the generated output fails the check
         # function should also return an empty list
-        mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = [VIOLENT_TEXT]
-        assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == []
+        mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]}
+        assert prompt_node("normal prompt") == []
         # case 3: both prompt and output pass the moderation check
         # function should return the output
-        mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = ["normal output"]
-        assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == ["normal output"]
+        mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]}
+        assert prompt_node("normal prompt") == ["normal output"]
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
+@patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt)
+def test_content_moderation_gpt_35():
+    """
+    Check all possible cases of the moderation checks passing / failing in a PromptNode that uses
+    ChatGPTInvocationLayer.
+    """
+    prompt_node = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True})
+    with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
+        "haystack.nodes.prompt.invocation_layer.open_ai.openai_request"
+    ) as mock_request:
+        VIOLENT_TEXT = "some violent text"
+        mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
+        # case 1: prompt fails the moderation check
+        # prompt should not be sent to OpenAi & function should return an empty list
+        mock_check.return_value = True
+        assert prompt_node(VIOLENT_TEXT) == []
+        # case 2: prompt passes the moderation check but the generated output fails the check
+        # function should also return an empty list
+        mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]}
+        assert prompt_node("normal prompt") == []
+        # case 3: both prompt and output pass the moderation check
+        # function should return the output
+        mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]}
+        assert prompt_node("normal prompt") == ["normal output"]


 @patch("haystack.nodes.prompt.prompt_node.PromptModel")
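The combined test is split per invocation layer, and the stub moves down a level: with `_execute_openai_request` removed, the tests mock `openai_request` itself, so the canned return value has to look like a real completions response. A small sketch of the shape the mocked transport must return and how `invoke` consumes it (`finish_reason` is read by `_check_openai_finish_reason`; the `text` of each choice becomes a response):

# the payload shape the tests above stub in
fake_completion = {"choices": [{"text": "normal output", "finish_reason": ""}]}

# invoke() extracts responses the same way (see the open_ai.py hunk)
responses = [ans["text"].strip() for ans in fake_completion["choices"]]
assert responses == ["normal output"]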