feat: Optional Content Moderation for OpenAI PromptNode & OpenAIAnswerGenerator (#5017)

* #4071 implemented optional content moderation for OpenAI PromptNode

* added two simple integration tests

* improved documentation & renamed _invoke method to _execute_openai_request

* added a flag to check_openai_policy_violation that will return a full dict of all text violations and their categories

* re-implemented the tests as unit tests & without use of the OpenAI APIs

* removed unused patch

* changed check_openai_policy_violation back to only return a bool

* fixed pylint and test error

---------

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Ben Heckmann 2023-06-19 13:27:11 +02:00 committed by GitHub
parent 97f136b901
commit 1318ac5074
6 changed files with 165 additions and 53 deletions
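
For orientation, a rough usage sketch of the new flag from the PromptNode side (the model name, API key, and prompt below are placeholders; the flag is read from model_kwargs, as the unit tests further down show):

    from haystack.nodes import PromptNode

    prompt_node = PromptNode(
        model_name_or_path="gpt-3.5-turbo",
        api_key="YOUR_OPENAI_API_KEY",
        model_kwargs={"moderate_content": True},
    )
    # If the prompt or the generated text is flagged by the OpenAI Moderation API,
    # the invocation layer returns an empty list instead of the responses.
    responses = prompt_node("Summarize the moderation guidelines.")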

View File

@@ -11,6 +11,7 @@ from haystack.utils.openai_utils import (
     openai_request,
     _openai_text_completion_tokenization_details,
     _check_openai_finish_reason,
+    check_openai_policy_violation,
 )
 
 logger = logging.getLogger(__name__)
@@ -45,6 +46,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
         progress_bar: bool = True,
         prompt_template: Optional[PromptTemplate] = None,
         context_join_str: str = " ",
+        moderate_content: bool = False,
         api_base: str = "https://api.openai.com/v1",
     ):
         """
@@ -99,6 +101,9 @@ class OpenAIAnswerGenerator(BaseGenerator):
             [PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure).
         :param context_join_str: The separation string used to join the input documents to create the context
             used by the PromptTemplate.
+        :param moderate_content: Whether to filter input and generated answers for potentially sensitive content
+            using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation). If the input or
+            answers are flagged, an empty list is returned in place of the answers.
         :param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`.
         """
         super().__init__(progress_bar=progress_bar)
@@ -159,6 +164,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
         self.prompt_template = prompt_template
         self.context_join_str = context_join_str
         self.using_azure = self.azure_deployment_name is not None and self.azure_base_url is not None
+        self.moderate_content = moderate_content
 
         tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=self.model)
@@ -228,9 +234,19 @@ class OpenAIAnswerGenerator(BaseGenerator):
         else:
             headers["Authorization"] = f"Bearer {self.api_key}"
 
+        if self.moderate_content and check_openai_policy_violation(input=prompt, headers=headers):
+            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+            return {"query": query, "answers": []}
+
+        logger.debug("Prompt being sent to OpenAI API with prompt %s.", prompt)
         res = openai_request(url=url, headers=headers, payload=payload, timeout=timeout)
         _check_openai_finish_reason(result=res, payload=payload)
         generated_answers = [ans["text"] for ans in res["choices"]]
+        if self.moderate_content and check_openai_policy_violation(input=generated_answers, headers=headers):
+            logger.info(
+                "Generated answers '%s' will not be returned due to potential policy violation.", generated_answers
+            )
+            return {"query": query, "answers": []}
         answers = self._create_answers(generated_answers, input_docs, prompt=prompt)
         result = {"query": query, "answers": answers}
         return result
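
A minimal sketch of the same flag on OpenAIAnswerGenerator (retriever and pipeline wiring omitted; the API key, query, and document content are placeholders):

    from haystack import Document
    from haystack.nodes import OpenAIAnswerGenerator

    generator = OpenAIAnswerGenerator(api_key="YOUR_OPENAI_API_KEY", moderate_content=True)
    prediction = generator.predict(
        query="What does the moderate_content flag do?",
        documents=[Document(content="Some context passage.")],
        top_k=1,
    )
    # prediction["answers"] is an empty list whenever the prompt or the generated answers are flagged.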

View File

@@ -36,20 +36,18 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
         :param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
         :param kwargs: Additional keyword arguments passed to the underlying model.
         [See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).
+        Note: additional model argument moderate_content will filter input and generated answers for potentially
+        sensitive content using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation)
+        if set. If the input or answers are flagged, an empty list is returned in place of the answers.
         """
         super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)
 
-    def invoke(self, *args, **kwargs):
+    def _execute_openai_request(
+        self, prompt: Union[str, List[Dict]], base_payload: Dict, kwargs_with_defaults: Dict, stream: bool
+    ):
         """
-        It takes in either a prompt or a list of messages and returns a list of responses, using a REST invocation.
-
-        :return: A list of generated responses.
-
-        Note: Only kwargs relevant to OpenAI are passed to OpenAI rest API. Others kwargs are ignored.
         For more details, see [OpenAI ChatGPT API reference](https://platform.openai.com/docs/api-reference/chat).
         """
-        prompt = kwargs.get("prompt", None)
         if isinstance(prompt, str):
             messages = [{"role": "user", "content": prompt}]
         elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
@@ -60,34 +58,8 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
                 f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
                 f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
             )
-
-        kwargs_with_defaults = self.model_input_kwargs
-        if kwargs:
-            # we use keyword stop_words but OpenAI uses stop
-            if "stop_words" in kwargs:
-                kwargs["stop"] = kwargs.pop("stop_words")
-            if "top_k" in kwargs:
-                top_k = kwargs.pop("top_k")
-                kwargs["n"] = top_k
-                kwargs["best_of"] = top_k
-            kwargs_with_defaults.update(kwargs)
-
-        stream = (
-            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
-        )
-        payload = {
-            "model": self.model_name_or_path,
-            "messages": messages,
-            "max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
-            "temperature": kwargs_with_defaults.get("temperature", 0.7),
-            "top_p": kwargs_with_defaults.get("top_p", 1),
-            "n": kwargs_with_defaults.get("n", 1),
-            "stream": stream,
-            "stop": kwargs_with_defaults.get("stop", None),
-            "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
-            "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
-            "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
-        }
+        extra_payload = {"messages": messages}
+        payload = {**base_payload, **extra_payload}
         if not stream:
             response = openai_request(url=self.url, headers=self.headers, payload=payload)
             _check_openai_finish_reason(result=response, payload=payload)

View File

@@ -10,6 +10,7 @@ from haystack.utils.openai_utils import (
     _openai_text_completion_tokenization_details,
     load_openai_tokenizer,
     _check_openai_finish_reason,
+    check_openai_policy_violation,
 )
 from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
@@ -47,6 +48,9 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
         kwargs includes: suffix, temperature, top_p, presence_penalty, frequency_penalty, best_of, n, max_tokens,
         logit_bias, stop, echo, and logprobs. For more details about these kwargs, see OpenAI
         [documentation](https://platform.openai.com/docs/api-reference/completions/create).
+        Note: additional model argument moderate_content will filter input and generated answers for potentially
+        sensitive content using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation)
+        if set. If the input or answers are flagged, an empty list is returned in place of the answers.
         """
         super().__init__(model_name_or_path)
         if not isinstance(api_key, str) or len(api_key) == 0:
@@ -81,6 +85,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                 "logit_bias",
                 "stream",
                 "stream_handler",
+                "moderate_content",
             ]
             if key in kwargs
         }
@@ -101,7 +106,8 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
     def invoke(self, *args, **kwargs):
         """
-        Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation.
+        Invokes a prompt on the model. Based on the model, it takes in a prompt (or either a prompt or a list of messages)
+        and returns a list of responses using a REST invocation.
 
         :return: The responses are being returned.
@@ -109,12 +115,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
         For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
         """
         prompt = kwargs.get("prompt")
-        if not prompt:
-            raise ValueError(
-                f"No prompt provided. Model {self.model_name_or_path} requires prompt."
-                f"Make sure to provide prompt in kwargs."
-            )
-
+        # either stream is True (will use default handler) or stream_handler is provided
         kwargs_with_defaults = self.model_input_kwargs
         if kwargs:
             # we use keyword stop_words but OpenAI uses stop
@@ -125,28 +126,47 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                 kwargs["n"] = top_k
                 kwargs["best_of"] = top_k
             kwargs_with_defaults.update(kwargs)
-
-        # either stream is True (will use default handler) or stream_handler is provided
         stream = (
             kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
         )
-        payload = {
+        moderation = kwargs_with_defaults.get("moderate_content", False)
+        base_payload = {  # payload common to all OpenAI models
             "model": self.model_name_or_path,
-            "prompt": prompt,
-            "suffix": kwargs_with_defaults.get("suffix", None),
             "max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
             "temperature": kwargs_with_defaults.get("temperature", 0.7),
             "top_p": kwargs_with_defaults.get("top_p", 1),
             "n": kwargs_with_defaults.get("n", 1),
             "stream": stream,
-            "logprobs": kwargs_with_defaults.get("logprobs", None),
-            "echo": kwargs_with_defaults.get("echo", False),
             "stop": kwargs_with_defaults.get("stop", None),
             "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
             "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
-            "best_of": kwargs_with_defaults.get("best_of", 1),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }
+        if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
+            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+            return []
+        responses = self._execute_openai_request(
+            prompt=prompt, base_payload=base_payload, kwargs_with_defaults=kwargs_with_defaults, stream=stream
+        )
+        if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
+            return []
+        return responses
+
+    def _execute_openai_request(self, prompt: str, base_payload: Dict, kwargs_with_defaults: Dict, stream: bool):
+        if not prompt:
+            raise ValueError(
+                f"No prompt provided. Model {self.model_name_or_path} requires prompt."
+                f"Make sure to provide prompt in kwargs."
+            )
+        extra_payload = {
+            "prompt": prompt,
+            "suffix": kwargs_with_defaults.get("suffix", None),
+            "logprobs": kwargs_with_defaults.get("logprobs", None),
+            "echo": kwargs_with_defaults.get("echo", False),
+            "best_of": kwargs_with_defaults.get("best_of", 1),
+        }
+        payload = {**base_payload, **extra_payload}
         if not stream:
             res = openai_request(url=self.url, headers=self.headers, payload=payload)
             _check_openai_finish_reason(result=res, payload=payload)
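
The refactor above splits request construction in two: invoke() builds a base_payload shared by the completion and chat layers, while each layer's _execute_openai_request() merges in its model-family-specific fields. A small illustrative sketch of that merge (the values are placeholders, not the defaults used above):

    base_payload = {"model": "text-davinci-003", "max_tokens": 16, "stream": False}
    extra_payload = {"prompt": "Tell me a joke", "suffix": None}  # chat models carry {"messages": [...]} instead
    payload = {**base_payload, **extra_payload}  # later keys win on conflicts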

View File

@@ -18,7 +18,6 @@ from haystack.environment import (
 logger = logging.getLogger(__name__)
 
 machine = platform.machine().lower()
 system = platform.system()
@@ -26,6 +25,8 @@ OPENAI_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
 OPENAI_BACKOFF = int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
 OPENAI_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
+OPENAI_MODERATION_URL = "https://api.openai.com/v1/moderations"
+
 
 def load_openai_tokenizer(tokenizer_name: str):
     """Load either the tokenizer from tiktoken (if the library is available) or fallback to the GPT2TokenizerFast
@@ -150,6 +151,26 @@
     return response
 
 
+def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
+    """
+    Calls the moderation endpoint to check if the text(s) violate the policy.
+    See [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation) for more details.
+    Returns true if any of the input is flagged as any of ['sexual', 'hate', 'violence', 'self-harm', 'sexual/minors', 'hate/threatening', 'violence/graphic'].
+    """
+    response = openai_request(url=OPENAI_MODERATION_URL, headers=headers, payload={"input": input})
+    results = response["results"]
+    flagged = any(res["flagged"] for res in results)
+    if flagged:
+        for result in results:
+            if result["flagged"]:
+                logger.debug(
+                    "OpenAI Moderation API flagged the text '%s' as a potential policy violation of the following categories: %s",
+                    input,
+                    result["categories"],
+                )
+    return flagged
+
+
 def _check_openai_finish_reason(result: Dict, payload: Dict) -> None:
     """Check the `finish_reason` the answers returned by OpenAI completions endpoint.
     If the `finish_reason` is `length` or `content_filter`, log a warning to the user.
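
For reference, a short sketch of calling the new helper directly; the api_key variable and input text are placeholders, and the headers mirror what the callers above pass in:

    from haystack.utils.openai_utils import check_openai_policy_violation

    headers = {"Authorization": f"Bearer {api_key}"}  # api_key assumed to hold a valid OpenAI key
    if check_openai_policy_violation(input=["some text to screen"], headers=headers):
        print("Flagged by the OpenAI Moderation API")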

View File

@@ -1013,6 +1013,40 @@ def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model):
 
 @pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
+@patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt)
+def test_content_moderation_gpt_3_and_gpt_3_5():
+    """
+    Check all possible cases of the moderation checks passing / failing in a PromptNode call
+    for both ChatGPTInvocationLayer and OpenAIInvocationLayer.
+    """
+    prompt_node_gpt_3_5 = PromptNode(
+        model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True}
+    )
+    prompt_node_gpt_3 = PromptNode(
+        model_name_or_path="text-davinci-003", api_key="key", model_kwargs={"moderate_content": True}
+    )
+    with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
+        "haystack.nodes.prompt.invocation_layer.chatgpt.ChatGPTInvocationLayer._execute_openai_request"
+    ) as mock_execute_gpt_3_5, patch(
+        "haystack.nodes.prompt.invocation_layer.open_ai.OpenAIInvocationLayer._execute_openai_request"
+    ) as mock_execute_gpt_3:
+        VIOLENT_TEXT = "some violent text"
+        mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
+        # case 1: prompt fails the moderation check
+        # prompt should not be sent to OpenAi & function should return an empty list
+        mock_check.return_value = True
+        assert prompt_node_gpt_3_5(VIOLENT_TEXT) == prompt_node_gpt_3(VIOLENT_TEXT) == []
+        # case 2: prompt passes the moderation check but the generated output fails the check
+        # function should also return an empty list
+        mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = [VIOLENT_TEXT]
+        assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == []
+        # case 3: both prompt and output pass the moderation check
+        # function should return the output
+        mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = ["normal output"]
+        assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == ["normal output"]
+
+
 @patch("haystack.nodes.prompt.prompt_node.PromptModel")
 def test_prompt_node_warns_about_missing_documents(mock_model, caplog):
     lfqa_prompt = PromptTemplate(

View File

@@ -1,3 +1,5 @@
+import copy
+
 import pytest
 
 from unittest.mock import patch

@@ -5,7 +7,11 @@ import pytest
 from tenacity import wait_none
 
 from haystack.errors import OpenAIError, OpenAIRateLimitError, OpenAIUnauthorizedError
-from haystack.utils.openai_utils import openai_request, _openai_text_completion_tokenization_details
+from haystack.utils.openai_utils import (
+    openai_request,
+    _openai_text_completion_tokenization_details,
+    check_openai_policy_violation,
+)
 
 
 @pytest.mark.unit
@@ -98,3 +104,46 @@ def test_openai_request_does_not_retry_on_success(mock_requests):
     openai_request.retry_with(wait=wait_none())(url="some_url", headers={}, payload={}, read_response=False)
 
     assert mock_requests.request.call_count == 1
+
+
+@pytest.mark.unit
+def test_check_openai_policy_violation():
+    moderation_endpoint_mock_response_flagged = {
+        "id": "modr-7Ok9zndoeSn5ij654vuNCgFVomU4U",
+        "model": "text-moderation-004",
+        "results": [
+            {
+                "flagged": True,
+                "categories": {
+                    "sexual": False,
+                    "hate": False,
+                    "violence": True,
+                    "self-harm": True,
+                    "sexual/minors": False,
+                    "hate/threatening": False,
+                    "violence/graphic": False,
+                },
+                "category_scores": {
+                    "sexual": 2.6659495e-06,
+                    "hate": 1.9359974e-05,
+                    "violence": 0.95964026,
+                    "self-harm": 0.9696306,
+                    "sexual/minors": 4.1061935e-07,
+                    "hate/threatening": 4.9856953e-07,
+                    "violence/graphic": 0.2683866,
+                },
+            }
+        ],
+    }
+    moderation_endpoint_mock_response_not_flagged = copy.deepcopy(moderation_endpoint_mock_response_flagged)
+    moderation_endpoint_mock_response_not_flagged["results"][0]["flagged"] = False
+    moderation_endpoint_mock_response_not_flagged["results"][0]["categories"].update(
+        {"violence": False, "self-harm": False}
+    )
+    with patch("haystack.utils.openai_utils.openai_request") as mock_openai_request:
+        # check that the function returns True if the input is flagged
+        mock_openai_request.return_value = moderation_endpoint_mock_response_flagged
+        assert check_openai_policy_violation(input="violent input", headers={}) == True
+        # check that the function returns False if the input is not flagged
+        mock_openai_request.return_value = moderation_endpoint_mock_response_not_flagged
+        assert check_openai_policy_violation(input="ok input", headers={}) == False