feat: Optional Content Moderation for OpenAI PromptNode & OpenAIAnswerGenerator (#5017)

* implemented optional content moderation for the OpenAI PromptNode (#4071)

* added two simple integration tests

* improved documentation & renamed _invoke method to _execute_openai_request

* added a flag to check_openai_policy_violation that will return a full dict of all text violations and their categories

* re-implemented the tests as unit tests & without use of the OpenAI APIs

* removed unused patch

* changed check_openai_policy_violation back to only return a bool

* fixed pylint and test errors

---------

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Ben Heckmann, 2023-06-19 13:27:11 +02:00, committed by GitHub
parent 97f136b901
commit 1318ac5074
6 changed files with 165 additions and 53 deletions

View File

@@ -11,6 +11,7 @@ from haystack.utils.openai_utils import (
openai_request,
_openai_text_completion_tokenization_details,
_check_openai_finish_reason,
check_openai_policy_violation,
)
logger = logging.getLogger(__name__)
@@ -45,6 +46,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
progress_bar: bool = True,
prompt_template: Optional[PromptTemplate] = None,
context_join_str: str = " ",
moderate_content: bool = False,
api_base: str = "https://api.openai.com/v1",
):
"""
@@ -99,6 +101,9 @@ class OpenAIAnswerGenerator(BaseGenerator):
[PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure).
:param context_join_str: The separation string used to join the input documents to create the context
used by the PromptTemplate.
:param moderate_content: Whether to filter input and generated answers for potentially sensitive content
using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation). If the input or
answers are flagged, an empty list is returned in place of the answers.
:param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`.
"""
super().__init__(progress_bar=progress_bar)
@@ -159,6 +164,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
self.prompt_template = prompt_template
self.context_join_str = context_join_str
self.using_azure = self.azure_deployment_name is not None and self.azure_base_url is not None
self.moderate_content = moderate_content
tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=self.model)
@@ -228,9 +234,19 @@ class OpenAIAnswerGenerator(BaseGenerator):
else:
headers["Authorization"] = f"Bearer {self.api_key}"
if self.moderate_content and check_openai_policy_violation(input=prompt, headers=headers):
logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
return {"query": query, "answers": []}
logger.debug("Prompt being sent to OpenAI API with prompt %s.", prompt)
res = openai_request(url=url, headers=headers, payload=payload, timeout=timeout)
_check_openai_finish_reason(result=res, payload=payload)
generated_answers = [ans["text"] for ans in res["choices"]]
if self.moderate_content and check_openai_policy_violation(input=generated_answers, headers=headers):
logger.info(
"Generated answers '%s' will not be returned due to potential policy violation.", generated_answers
)
return {"query": query, "answers": []}
answers = self._create_answers(generated_answers, input_docs, prompt=prompt)
result = {"query": query, "answers": answers}
return result

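As a usage sketch (not part of the diff), the new flag on `OpenAIAnswerGenerator` would be enabled like this; the API key and the sample document are placeholders:

```python
from haystack.nodes import OpenAIAnswerGenerator
from haystack.schema import Document

# Usage sketch only: "YOUR_API_KEY" and the sample document are placeholders.
generator = OpenAIAnswerGenerator(
    api_key="YOUR_API_KEY",
    moderate_content=True,  # the flag introduced in this commit
)
result = generator.predict(
    query="What is Haystack?",
    documents=[Document(content="Haystack is an open-source NLP framework.")],
)
# If the prompt or any generated answer is flagged by the OpenAI Moderation API,
# result["answers"] is an empty list.
print(result["answers"])
```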
View File

@@ -36,20 +36,18 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
:param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
:param kwargs: Additional keyword arguments passed to the underlying model.
[See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).
Note: if the additional model argument moderate_content is set, the input and generated answers are filtered
for potentially sensitive content using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation).
If the input or the answers are flagged, an empty list is returned in place of the answers.
"""
super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)
def invoke(self, *args, **kwargs):
def _execute_openai_request(
self, prompt: Union[str, List[Dict]], base_payload: Dict, kwargs_with_defaults: Dict, stream: bool
):
"""
It takes in either a prompt or a list of messages and returns a list of responses, using a REST invocation.
:return: A list of generated responses.
Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
For more details, see [OpenAI ChatGPT API reference](https://platform.openai.com/docs/api-reference/chat).
"""
prompt = kwargs.get("prompt", None)
if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
@@ -60,34 +58,8 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
)
kwargs_with_defaults = self.model_input_kwargs
if kwargs:
# we use keyword stop_words but OpenAI uses stop
if "stop_words" in kwargs:
kwargs["stop"] = kwargs.pop("stop_words")
if "top_k" in kwargs:
top_k = kwargs.pop("top_k")
kwargs["n"] = top_k
kwargs["best_of"] = top_k
kwargs_with_defaults.update(kwargs)
stream = (
kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
)
payload = {
"model": self.model_name_or_path,
"messages": messages,
"max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
"temperature": kwargs_with_defaults.get("temperature", 0.7),
"top_p": kwargs_with_defaults.get("top_p", 1),
"n": kwargs_with_defaults.get("n", 1),
"stream": stream,
"stop": kwargs_with_defaults.get("stop", None),
"presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
"frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
"logit_bias": kwargs_with_defaults.get("logit_bias", {}),
}
extra_payload = {"messages": messages}
payload = {**base_payload, **extra_payload}
if not stream:
response = openai_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_finish_reason(result=response, payload=payload)

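For reference, `ChatGPTInvocationLayer` accepts either a plain string or ChatML-style messages; a minimal sketch (the API key is a placeholder, and these calls would hit the live API):

```python
from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer

# Sketch only: "YOUR_API_KEY" is a placeholder.
layer = ChatGPTInvocationLayer(api_key="YOUR_API_KEY", model_name_or_path="gpt-3.5-turbo")

# 1) A plain string prompt is wrapped into [{"role": "user", "content": prompt}].
layer.invoke(prompt="Tell me about Berlin")

# 2) A ChatML-style list of message dicts is passed through as-is.
layer.invoke(
    prompt=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Tell me about Berlin"},
    ]
)
```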
View File

@@ -10,6 +10,7 @@ from haystack.utils.openai_utils import (
_openai_text_completion_tokenization_details,
load_openai_tokenizer,
_check_openai_finish_reason,
check_openai_policy_violation,
)
from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
@@ -47,6 +48,9 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
kwargs includes: suffix, temperature, top_p, presence_penalty, frequency_penalty, best_of, n, max_tokens,
logit_bias, stop, echo, and logprobs. For more details about these kwargs, see OpenAI
[documentation](https://platform.openai.com/docs/api-reference/completions/create).
Note: if the additional model argument moderate_content is set, the input and generated answers are filtered
for potentially sensitive content using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation).
If the input or the answers are flagged, an empty list is returned in place of the answers.
"""
super().__init__(model_name_or_path)
if not isinstance(api_key, str) or len(api_key) == 0:
@@ -81,6 +85,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
"logit_bias",
"stream",
"stream_handler",
"moderate_content",
]
if key in kwargs
}
@@ -101,7 +106,8 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
def invoke(self, *args, **kwargs):
"""
Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation.
Invokes a prompt on the model. Depending on the model, it takes either a plain prompt or a list of messages
and returns a list of responses using a REST invocation.
:return: A list of generated responses.
@@ -109,12 +115,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
"""
prompt = kwargs.get("prompt")
if not prompt:
raise ValueError(
f"No prompt provided. Model {self.model_name_or_path} requires prompt."
f"Make sure to provide prompt in kwargs."
)
# either stream is True (will use default handler) or stream_handler is provided
kwargs_with_defaults = self.model_input_kwargs
if kwargs:
# we use keyword stop_words but OpenAI uses stop
@@ -125,28 +126,47 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
kwargs["n"] = top_k
kwargs["best_of"] = top_k
kwargs_with_defaults.update(kwargs)
# either stream is True (will use default handler) or stream_handler is provided
stream = (
kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
)
payload = {
moderation = kwargs_with_defaults.get("moderate_content", False)
base_payload = { # payload common to all OpenAI models
"model": self.model_name_or_path,
"prompt": prompt,
"suffix": kwargs_with_defaults.get("suffix", None),
"max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
"temperature": kwargs_with_defaults.get("temperature", 0.7),
"top_p": kwargs_with_defaults.get("top_p", 1),
"n": kwargs_with_defaults.get("n", 1),
"stream": stream,
"logprobs": kwargs_with_defaults.get("logprobs", None),
"echo": kwargs_with_defaults.get("echo", False),
"stop": kwargs_with_defaults.get("stop", None),
"presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
"frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
"best_of": kwargs_with_defaults.get("best_of", 1),
"logit_bias": kwargs_with_defaults.get("logit_bias", {}),
}
if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
return []
responses = self._execute_openai_request(
prompt=prompt, base_payload=base_payload, kwargs_with_defaults=kwargs_with_defaults, stream=stream
)
if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
return []
return responses
def _execute_openai_request(self, prompt: str, base_payload: Dict, kwargs_with_defaults: Dict, stream: bool):
if not prompt:
raise ValueError(
f"No prompt provided. Model {self.model_name_or_path} requires prompt."
f"Make sure to provide prompt in kwargs."
)
extra_payload = {
"prompt": prompt,
"suffix": kwargs_with_defaults.get("suffix", None),
"logprobs": kwargs_with_defaults.get("logprobs", None),
"echo": kwargs_with_defaults.get("echo", False),
"best_of": kwargs_with_defaults.get("best_of", 1),
}
payload = {**base_payload, **extra_payload}
if not stream:
res = openai_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_finish_reason(result=res, payload=payload)

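At the `PromptNode` level, the moderation flow above is switched on through `model_kwargs`, as the tests further down also do; a sketch with a placeholder API key:

```python
from haystack.nodes import PromptNode

# Sketch only: "YOUR_API_KEY" is a placeholder.
node = PromptNode(
    model_name_or_path="text-davinci-003",
    api_key="YOUR_API_KEY",
    model_kwargs={"moderate_content": True},  # forwarded to the invocation layer
)
# Returns [] if the prompt or the generated text is flagged by the
# OpenAI Moderation API; otherwise the list of generated responses.
print(node("Tell me about Berlin"))
```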
View File

@@ -18,7 +18,6 @@ from haystack.environment import (
logger = logging.getLogger(__name__)
machine = platform.machine().lower()
system = platform.system()
@@ -26,6 +25,8 @@ OPENAI_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
OPENAI_BACKOFF = int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
OPENAI_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
OPENAI_MODERATION_URL = "https://api.openai.com/v1/moderations"
def load_openai_tokenizer(tokenizer_name: str):
"""Load either the tokenizer from tiktoken (if the library is available) or fallback to the GPT2TokenizerFast
@@ -150,6 +151,26 @@ def openai_request(
return response
def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
"""
Calls the moderation endpoint to check if the text(s) violate the policy.
See [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation) for more details.
Returns True if any of the inputs is flagged for any of the categories: 'sexual', 'hate', 'violence',
'self-harm', 'sexual/minors', 'hate/threatening', or 'violence/graphic'.
"""
response = openai_request(url=OPENAI_MODERATION_URL, headers=headers, payload={"input": input})
results = response["results"]
flagged = any(res["flagged"] for res in results)
if flagged:
for result in results:
if result["flagged"]:
logger.debug(
"OpenAI Moderation API flagged the text '%s' as a potential policy violation of the following categories: %s",
input,
result["categories"],
)
return flagged
def _check_openai_finish_reason(result: Dict, payload: Dict) -> None:
"""Check the `finish_reason` the answers returned by OpenAI completions endpoint.
If the `finish_reason` is `length` or `content_filter`, log a warning to the user.

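The new helper can also be called on its own; a sketch with a placeholder key — the Bearer-token header mirrors the one the invocation layers build:

```python
from haystack.utils.openai_utils import check_openai_policy_violation

# Sketch only: "YOUR_API_KEY" is a placeholder.
headers = {"Content-Type": "application/json", "Authorization": "Bearer YOUR_API_KEY"}

# Accepts a single string or a list of strings and returns a bool.
if check_openai_policy_violation(input="text to screen", headers=headers):
    print("Flagged by the OpenAI Moderation API")
```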
View File

@@ -1013,6 +1013,40 @@ def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model):
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
@patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt)
def test_content_moderation_gpt_3_and_gpt_3_5():
"""
Check all possible cases of the moderation checks passing / failing in a PromptNode call
for both ChatGPTInvocationLayer and OpenAIInvocationLayer.
"""
prompt_node_gpt_3_5 = PromptNode(
model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True}
)
prompt_node_gpt_3 = PromptNode(
model_name_or_path="text-davinci-003", api_key="key", model_kwargs={"moderate_content": True}
)
with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
"haystack.nodes.prompt.invocation_layer.chatgpt.ChatGPTInvocationLayer._execute_openai_request"
) as mock_execute_gpt_3_5, patch(
"haystack.nodes.prompt.invocation_layer.open_ai.OpenAIInvocationLayer._execute_openai_request"
) as mock_execute_gpt_3:
VIOLENT_TEXT = "some violent text"
mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
# case 1: prompt fails the moderation check (the side_effect above flags VIOLENT_TEXT)
# prompt should not be sent to OpenAI & the call should return an empty list
assert prompt_node_gpt_3_5(VIOLENT_TEXT) == prompt_node_gpt_3(VIOLENT_TEXT) == []
# case 2: prompt passes the moderation check but the generated output fails the check
# function should also return an empty list
mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = [VIOLENT_TEXT]
assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == []
# case 3: both prompt and output pass the moderation check
# function should return the output
mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = ["normal output"]
assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == ["normal output"]
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_node_warns_about_missing_documents(mock_model, caplog):
lfqa_prompt = PromptTemplate(

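The mock wiring that drives these cases is `side_effect` on the patched moderation check: it flags exactly the violent text, in either string or list form, and takes precedence over any `return_value`. A minimal standalone illustration of the pattern:

```python
from unittest.mock import MagicMock

VIOLENT_TEXT = "some violent text"
mock_check = MagicMock()
# When side_effect is set, it takes precedence over return_value.
mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]

assert mock_check(input=VIOLENT_TEXT, headers={}) is True      # prompt flagged
assert mock_check(input=[VIOLENT_TEXT], headers={}) is True    # responses flagged
assert mock_check(input="normal prompt", headers={}) is False  # passes the check
```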
View File

@@ -1,3 +1,5 @@
import copy
import pytest
from unittest.mock import patch
@@ -5,7 +7,11 @@ import pytest
from tenacity import wait_none
from haystack.errors import OpenAIError, OpenAIRateLimitError, OpenAIUnauthorizedError
from haystack.utils.openai_utils import openai_request, _openai_text_completion_tokenization_details
from haystack.utils.openai_utils import (
openai_request,
_openai_text_completion_tokenization_details,
check_openai_policy_violation,
)
@pytest.mark.unit
@@ -98,3 +104,46 @@ def test_openai_request_does_not_retry_on_success(mock_requests):
openai_request.retry_with(wait=wait_none())(url="some_url", headers={}, payload={}, read_response=False)
assert mock_requests.request.call_count == 1
@pytest.mark.unit
def test_check_openai_policy_violation():
moderation_endpoint_mock_response_flagged = {
"id": "modr-7Ok9zndoeSn5ij654vuNCgFVomU4U",
"model": "text-moderation-004",
"results": [
{
"flagged": True,
"categories": {
"sexual": False,
"hate": False,
"violence": True,
"self-harm": True,
"sexual/minors": False,
"hate/threatening": False,
"violence/graphic": False,
},
"category_scores": {
"sexual": 2.6659495e-06,
"hate": 1.9359974e-05,
"violence": 0.95964026,
"self-harm": 0.9696306,
"sexual/minors": 4.1061935e-07,
"hate/threatening": 4.9856953e-07,
"violence/graphic": 0.2683866,
},
}
],
}
moderation_endpoint_mock_response_not_flagged = copy.deepcopy(moderation_endpoint_mock_response_flagged)
moderation_endpoint_mock_response_not_flagged["results"][0]["flagged"] = False
moderation_endpoint_mock_response_not_flagged["results"][0]["categories"].update(
{"violence": False, "self-harm": False}
)
with patch("haystack.utils.openai_utils.openai_request") as mock_openai_request:
# check that the function returns True if the input is flagged
mock_openai_request.return_value = moderation_endpoint_mock_response_flagged
assert check_openai_policy_violation(input="violent input", headers={}) is True
# check that the function returns False if the input is not flagged
mock_openai_request.return_value = moderation_endpoint_mock_response_not_flagged
assert check_openai_policy_violation(input="ok input", headers={}) is False