Mirror of https://github.com/deepset-ai/haystack.git
feat: Optional Content Moderation for OpenAI PromptNode & OpenAIAnswerGenerator (#5017)
* #4071 implemented optional content moderation for OpenAI PromptNode
* added two simple integration tests
* improved documentation & renamed _invoke method to _execute_openai_request
* added a flag to check_openai_policy_violation that will return a full dict of all text violations and their categories
* re-implemented the tests as unit tests & without use of the OpenAI APIs
* removed unused patch
* changed check_openai_policy_violation back to only return a bool
* fixed pylint and test error

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Parent: 97f136b901
Commit: 1318ac5074
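For orientation before the diff: a minimal usage sketch of the new opt-in flag. It mirrors the PromptNode construction used in the new unit test further down; the API key and prompt strings are placeholders.

    from haystack.nodes import PromptNode

    # Moderation is opt-in: pass moderate_content=True through model_kwargs.
    # If the prompt or the generated text is flagged by the OpenAI Moderation API,
    # the node returns an empty list instead of generated text.
    prompt_node = PromptNode(
        model_name_or_path="gpt-3.5-turbo",
        api_key="OPENAI_API_KEY",  # placeholder
        model_kwargs={"moderate_content": True},
    )
    result = prompt_node("Tell me about content moderation.")  # [] if flagged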
OpenAIAnswerGenerator:

@@ -11,6 +11,7 @@ from haystack.utils.openai_utils import (
    openai_request,
    _openai_text_completion_tokenization_details,
    _check_openai_finish_reason,
+   check_openai_policy_violation,
)

logger = logging.getLogger(__name__)

@@ -45,6 +46,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
        progress_bar: bool = True,
        prompt_template: Optional[PromptTemplate] = None,
        context_join_str: str = " ",
+       moderate_content: bool = False,
        api_base: str = "https://api.openai.com/v1",
    ):
        """

@@ -99,6 +101,9 @@ class OpenAIAnswerGenerator(BaseGenerator):
            [PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure).
        :param context_join_str: The separation string used to join the input documents to create the context
            used by the PromptTemplate.
+       :param moderate_content: Whether to filter input and generated answers for potentially sensitive content
+           using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation). If the input or
+           answers are flagged, an empty list is returned in place of the answers.
        :param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`.
        """
        super().__init__(progress_bar=progress_bar)

@@ -159,6 +164,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
        self.prompt_template = prompt_template
        self.context_join_str = context_join_str
        self.using_azure = self.azure_deployment_name is not None and self.azure_base_url is not None
+       self.moderate_content = moderate_content

        tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=self.model)

@@ -228,9 +234,19 @@ class OpenAIAnswerGenerator(BaseGenerator):
        else:
            headers["Authorization"] = f"Bearer {self.api_key}"

+       if self.moderate_content and check_openai_policy_violation(input=prompt, headers=headers):
+           logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+           return {"query": query, "answers": []}
+
+       logger.debug("Prompt being sent to OpenAI API with prompt %s.", prompt)
        res = openai_request(url=url, headers=headers, payload=payload, timeout=timeout)
        _check_openai_finish_reason(result=res, payload=payload)
        generated_answers = [ans["text"] for ans in res["choices"]]
+       if self.moderate_content and check_openai_policy_violation(input=generated_answers, headers=headers):
+           logger.info(
+               "Generated answers '%s' will not be returned due to potential policy violation.", generated_answers
+           )
+           return {"query": query, "answers": []}
        answers = self._create_answers(generated_answers, input_docs, prompt=prompt)
        result = {"query": query, "answers": answers}
        return result
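A brief sketch of the same flag on the generator side; the API key and document content are placeholders, and the query is illustrative only.

    from haystack import Document
    from haystack.nodes import OpenAIAnswerGenerator

    generator = OpenAIAnswerGenerator(
        api_key="OPENAI_API_KEY",  # placeholder
        moderate_content=True,     # new parameter added in this commit, defaults to False
    )
    result = generator.predict(query="What is content moderation?", documents=[Document(content="...")])
    # result["answers"] is [] when the prompt or the generated answers are flagged.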
ChatGPTInvocationLayer (haystack.nodes.prompt.invocation_layer.chatgpt):

@@ -36,20 +36,18 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
        :param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
        :param kwargs: Additional keyword arguments passed to the underlying model.
            [See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).
+           Note: additional model argument moderate_content will filter input and generated answers for potentially
+           sensitive content using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation)
+           if set. If the input or answers are flagged, an empty list is returned in place of the answers.
        """
        super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)

-   def invoke(self, *args, **kwargs):
+   def _execute_openai_request(
+       self, prompt: Union[str, List[Dict]], base_payload: Dict, kwargs_with_defaults: Dict, stream: bool
+   ):
        """
-       It takes in either a prompt or a list of messages and returns a list of responses, using a REST invocation.
-
-       :return: A list of generated responses.
-
-       Note: Only kwargs relevant to OpenAI are passed to OpenAI rest API. Others kwargs are ignored.
        For more details, see [OpenAI ChatGPT API reference](https://platform.openai.com/docs/api-reference/chat).
        """
-       prompt = kwargs.get("prompt", None)
-
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):

@@ -60,34 +58,8 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
            )
-       kwargs_with_defaults = self.model_input_kwargs
-       if kwargs:
-           # we use keyword stop_words but OpenAI uses stop
-           if "stop_words" in kwargs:
-               kwargs["stop"] = kwargs.pop("stop_words")
-           if "top_k" in kwargs:
-               top_k = kwargs.pop("top_k")
-               kwargs["n"] = top_k
-               kwargs["best_of"] = top_k
-           kwargs_with_defaults.update(kwargs)
-
-       stream = (
-           kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
-       )
-       payload = {
-           "model": self.model_name_or_path,
-           "messages": messages,
-           "max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
-           "temperature": kwargs_with_defaults.get("temperature", 0.7),
-           "top_p": kwargs_with_defaults.get("top_p", 1),
-           "n": kwargs_with_defaults.get("n", 1),
-           "stream": stream,
-           "stop": kwargs_with_defaults.get("stop", None),
-           "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
-           "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
-           "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
-       }
+       extra_payload = {"messages": messages}
+       payload = {**base_payload, **extra_payload}
        if not stream:
            response = openai_request(url=self.url, headers=self.headers, payload=payload)
            _check_openai_finish_reason(result=response, payload=payload)
OpenAIInvocationLayer (haystack.nodes.prompt.invocation_layer.open_ai):

@@ -10,6 +10,7 @@ from haystack.utils.openai_utils import (
    _openai_text_completion_tokenization_details,
    load_openai_tokenizer,
    _check_openai_finish_reason,
+   check_openai_policy_violation,
)
from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler

@@ -47,6 +48,9 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
        kwargs includes: suffix, temperature, top_p, presence_penalty, frequency_penalty, best_of, n, max_tokens,
        logit_bias, stop, echo, and logprobs. For more details about these kwargs, see OpenAI
        [documentation](https://platform.openai.com/docs/api-reference/completions/create).
+       Note: additional model argument moderate_content will filter input and generated answers for potentially
+       sensitive content using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation)
+       if set. If the input or answers are flagged, an empty list is returned in place of the answers.
        """
        super().__init__(model_name_or_path)
        if not isinstance(api_key, str) or len(api_key) == 0:

@@ -81,6 +85,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                "logit_bias",
                "stream",
                "stream_handler",
+               "moderate_content",
            ]
            if key in kwargs
        }

@@ -101,7 +106,8 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):

    def invoke(self, *args, **kwargs):
        """
-       Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation.
+       Invokes a prompt on the model. Based on the model, it takes in a prompt (or either a prompt or a list of messages)
+       and returns a list of responses using a REST invocation.

        :return: The responses are being returned.

@@ -109,12 +115,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
        For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
        """
        prompt = kwargs.get("prompt")
-       if not prompt:
-           raise ValueError(
-               f"No prompt provided. Model {self.model_name_or_path} requires prompt."
-               f"Make sure to provide prompt in kwargs."
-           )
-
+       # either stream is True (will use default handler) or stream_handler is provided
        kwargs_with_defaults = self.model_input_kwargs
        if kwargs:
            # we use keyword stop_words but OpenAI uses stop

@@ -125,28 +126,47 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                kwargs["n"] = top_k
                kwargs["best_of"] = top_k
            kwargs_with_defaults.update(kwargs)

-       # either stream is True (will use default handler) or stream_handler is provided
        stream = (
            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
        )
-       payload = {
+       moderation = kwargs_with_defaults.get("moderate_content", False)
+       base_payload = {  # payload common to all OpenAI models
            "model": self.model_name_or_path,
-           "prompt": prompt,
-           "suffix": kwargs_with_defaults.get("suffix", None),
            "max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
            "temperature": kwargs_with_defaults.get("temperature", 0.7),
            "top_p": kwargs_with_defaults.get("top_p", 1),
            "n": kwargs_with_defaults.get("n", 1),
            "stream": stream,
-           "logprobs": kwargs_with_defaults.get("logprobs", None),
-           "echo": kwargs_with_defaults.get("echo", False),
            "stop": kwargs_with_defaults.get("stop", None),
            "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
            "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
-           "best_of": kwargs_with_defaults.get("best_of", 1),
            "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
        }
+       if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
+           logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+           return []
+       responses = self._execute_openai_request(
+           prompt=prompt, base_payload=base_payload, kwargs_with_defaults=kwargs_with_defaults, stream=stream
+       )
+       if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
+           logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
+           return []
+       return responses
+
+   def _execute_openai_request(self, prompt: str, base_payload: Dict, kwargs_with_defaults: Dict, stream: bool):
+       if not prompt:
+           raise ValueError(
+               f"No prompt provided. Model {self.model_name_or_path} requires prompt."
+               f"Make sure to provide prompt in kwargs."
+           )
+       extra_payload = {
+           "prompt": prompt,
+           "suffix": kwargs_with_defaults.get("suffix", None),
+           "logprobs": kwargs_with_defaults.get("logprobs", None),
+           "echo": kwargs_with_defaults.get("echo", False),
+           "best_of": kwargs_with_defaults.get("best_of", 1),
+       }
+       payload = {**base_payload, **extra_payload}
        if not stream:
            res = openai_request(url=self.url, headers=self.headers, payload=payload)
            _check_openai_finish_reason(result=res, payload=payload)
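For completeness, a hypothetical direct use of the invocation layer that illustrates the flow added above; in practice PromptNode is the usual entry point, and the key is a placeholder.

    from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer

    layer = OpenAIInvocationLayer(
        api_key="OPENAI_API_KEY",  # placeholder
        model_name_or_path="text-davinci-003",
        moderate_content=True,  # kept via the supported-kwargs list extended in this commit
    )
    responses = layer.invoke(prompt="some prompt")
    # invoke() now short-circuits to [] if either the prompt or the responses
    # are flagged by check_openai_policy_violation().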
haystack.utils.openai_utils:

@@ -18,7 +18,6 @@ from haystack.environment import (

logger = logging.getLogger(__name__)

-
machine = platform.machine().lower()
system = platform.system()

@@ -26,6 +25,8 @@ OPENAI_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
OPENAI_BACKOFF = int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
OPENAI_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))

+OPENAI_MODERATION_URL = "https://api.openai.com/v1/moderations"
+

def load_openai_tokenizer(tokenizer_name: str):
    """Load either the tokenizer from tiktoken (if the library is available) or fallback to the GPT2TokenizerFast

@@ -150,6 +151,26 @@ def openai_request(
    return response


+def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
+   """
+   Calls the moderation endpoint to check if the text(s) violate the policy.
+   See [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation) for more details.
+   Returns true if any of the input is flagged as any of ['sexual', 'hate', 'violence', 'self-harm', 'sexual/minors', 'hate/threatening', 'violence/graphic'].
+   """
+   response = openai_request(url=OPENAI_MODERATION_URL, headers=headers, payload={"input": input})
+   results = response["results"]
+   flagged = any(res["flagged"] for res in results)
+   if flagged:
+       for result in results:
+           if result["flagged"]:
+               logger.debug(
+                   "OpenAI Moderation API flagged the text '%s' as a potential policy violation of the following categories: %s",
+                   input,
+                   result["categories"],
+               )
+   return flagged
+
+
def _check_openai_finish_reason(result: Dict, payload: Dict) -> None:
    """Check the `finish_reason` the answers returned by OpenAI completions endpoint.
    If the `finish_reason` is `length` or `content_filter`, log a warning to the user.
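A short sketch of calling the new helper directly; the bearer-token header mirrors what the callers above construct, and the key is a placeholder.

    from haystack.utils.openai_utils import check_openai_policy_violation

    headers = {"Content-Type": "application/json", "Authorization": "Bearer OPENAI_API_KEY"}  # placeholder key
    if check_openai_policy_violation(input=["some text to check"], headers=headers):
        print("Flagged by the OpenAI Moderation API")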
PromptNode tests:

@@ -1013,6 +1013,40 @@ def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model):


@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
+@patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt)
+def test_content_moderation_gpt_3_and_gpt_3_5():
+   """
+   Check all possible cases of the moderation checks passing / failing in a PromptNode call
+   for both ChatGPTInvocationLayer and OpenAIInvocationLayer.
+   """
+   prompt_node_gpt_3_5 = PromptNode(
+       model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True}
+   )
+   prompt_node_gpt_3 = PromptNode(
+       model_name_or_path="text-davinci-003", api_key="key", model_kwargs={"moderate_content": True}
+   )
+   with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
+       "haystack.nodes.prompt.invocation_layer.chatgpt.ChatGPTInvocationLayer._execute_openai_request"
+   ) as mock_execute_gpt_3_5, patch(
+       "haystack.nodes.prompt.invocation_layer.open_ai.OpenAIInvocationLayer._execute_openai_request"
+   ) as mock_execute_gpt_3:
+       VIOLENT_TEXT = "some violent text"
+       mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
+       # case 1: prompt fails the moderation check
+       # prompt should not be sent to OpenAi & function should return an empty list
+       mock_check.return_value = True
+       assert prompt_node_gpt_3_5(VIOLENT_TEXT) == prompt_node_gpt_3(VIOLENT_TEXT) == []
+       # case 2: prompt passes the moderation check but the generated output fails the check
+       # function should also return an empty list
+       mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = [VIOLENT_TEXT]
+       assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == []
+       # case 3: both prompt and output pass the moderation check
+       # function should return the output
+       mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = ["normal output"]
+       assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == ["normal output"]
+
+
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_node_warns_about_missing_documents(mock_model, caplog):
    lfqa_prompt = PromptTemplate(
openai_utils tests:

@@ -1,3 +1,5 @@
+import copy
+
import pytest
from unittest.mock import patch

@@ -5,7 +7,11 @@ import pytest
from tenacity import wait_none

from haystack.errors import OpenAIError, OpenAIRateLimitError, OpenAIUnauthorizedError
-from haystack.utils.openai_utils import openai_request, _openai_text_completion_tokenization_details
+from haystack.utils.openai_utils import (
+   openai_request,
+   _openai_text_completion_tokenization_details,
+   check_openai_policy_violation,
+)


@pytest.mark.unit

@@ -98,3 +104,46 @@ def test_openai_request_does_not_retry_on_success(mock_requests):
    openai_request.retry_with(wait=wait_none())(url="some_url", headers={}, payload={}, read_response=False)

    assert mock_requests.request.call_count == 1
+
+
+@pytest.mark.unit
+def test_check_openai_policy_violation():
+   moderation_endpoint_mock_response_flagged = {
+       "id": "modr-7Ok9zndoeSn5ij654vuNCgFVomU4U",
+       "model": "text-moderation-004",
+       "results": [
+           {
+               "flagged": True,
+               "categories": {
+                   "sexual": False,
+                   "hate": False,
+                   "violence": True,
+                   "self-harm": True,
+                   "sexual/minors": False,
+                   "hate/threatening": False,
+                   "violence/graphic": False,
+               },
+               "category_scores": {
+                   "sexual": 2.6659495e-06,
+                   "hate": 1.9359974e-05,
+                   "violence": 0.95964026,
+                   "self-harm": 0.9696306,
+                   "sexual/minors": 4.1061935e-07,
+                   "hate/threatening": 4.9856953e-07,
+                   "violence/graphic": 0.2683866,
+               },
+           }
+       ],
+   }
+   moderation_endpoint_mock_response_not_flagged = copy.deepcopy(moderation_endpoint_mock_response_flagged)
+   moderation_endpoint_mock_response_not_flagged["results"][0]["flagged"] = False
+   moderation_endpoint_mock_response_not_flagged["results"][0]["categories"].update(
+       {"violence": False, "self-harm": False}
+   )
+   with patch("haystack.utils.openai_utils.openai_request") as mock_openai_request:
+       # check that the function returns True if the input is flagged
+       mock_openai_request.return_value = moderation_endpoint_mock_response_flagged
+       assert check_openai_policy_violation(input="violent input", headers={}) == True
+       # check that the function returns False if the input is not flagged
+       mock_openai_request.return_value = moderation_endpoint_mock_response_not_flagged
+       assert check_openai_policy_violation(input="ok input", headers={}) == False