feat: Optional Content Moderation for OpenAI PromptNode & OpenAIAnswerGenerator (#5017)
* #4071 implemented optional content moderation for OpenAI PromptNode
* added two simple integration tests
* improved documentation & renamed _invoke method to _execute_openai_request
* added a flag to check_openai_policy_violation that will return a full dict of all text violations and their categories
* re-implemented the tests as unit tests & without use of the OpenAI APIs
* removed unused patch
* changed check_openai_policy_violation back to only return a bool
* fixed pylint and test error

---------

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
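For context, a minimal usage sketch of the new flag, mirroring the unit tests added in this PR (the API key and prompt are placeholders):

```python
from haystack.nodes import PromptNode

# Enable moderation through model_kwargs; "YOUR_OPENAI_KEY" is a placeholder.
prompt_node = PromptNode(
    model_name_or_path="gpt-3.5-turbo",
    api_key="YOUR_OPENAI_KEY",
    model_kwargs={"moderate_content": True},
)

# If either the prompt or the generated text is flagged by the OpenAI Moderation API,
# the node returns an empty list instead of the generated answers.
result = prompt_node("Write a short, friendly greeting.")
print(result)
```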
parent 97f136b901
commit 1318ac5074
@@ -11,6 +11,7 @@ from haystack.utils.openai_utils import (
    openai_request,
    _openai_text_completion_tokenization_details,
    _check_openai_finish_reason,
    check_openai_policy_violation,
)

logger = logging.getLogger(__name__)
@@ -45,6 +46,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
        progress_bar: bool = True,
        prompt_template: Optional[PromptTemplate] = None,
        context_join_str: str = " ",
        moderate_content: bool = False,
        api_base: str = "https://api.openai.com/v1",
    ):
        """
@@ -99,6 +101,9 @@ class OpenAIAnswerGenerator(BaseGenerator):
            [PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure).
        :param context_join_str: The separation string used to join the input documents to create the context
            used by the PromptTemplate.
        :param moderate_content: Whether to filter input and generated answers for potentially sensitive content
            using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation). If the input or
            answers are flagged, an empty list is returned in place of the answers.
        :param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`.
        """
        super().__init__(progress_bar=progress_bar)
@@ -159,6 +164,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
        self.prompt_template = prompt_template
        self.context_join_str = context_join_str
        self.using_azure = self.azure_deployment_name is not None and self.azure_base_url is not None
        self.moderate_content = moderate_content

        tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=self.model)

@@ -228,9 +234,19 @@ class OpenAIAnswerGenerator(BaseGenerator):
        else:
            headers["Authorization"] = f"Bearer {self.api_key}"

        if self.moderate_content and check_openai_policy_violation(input=prompt, headers=headers):
            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
            return {"query": query, "answers": []}

        logger.debug("Prompt being sent to OpenAI API with prompt %s.", prompt)
        res = openai_request(url=url, headers=headers, payload=payload, timeout=timeout)
        _check_openai_finish_reason(result=res, payload=payload)
        generated_answers = [ans["text"] for ans in res["choices"]]
        if self.moderate_content and check_openai_policy_violation(input=generated_answers, headers=headers):
            logger.info(
                "Generated answers '%s' will not be returned due to potential policy violation.", generated_answers
            )
            return {"query": query, "answers": []}
        answers = self._create_answers(generated_answers, input_docs, prompt=prompt)
        result = {"query": query, "answers": answers}
        return result
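A similar sketch for the generator side, assuming the default model settings (the key, query, and document text are placeholders):

```python
from haystack import Document
from haystack.nodes import OpenAIAnswerGenerator

generator = OpenAIAnswerGenerator(api_key="YOUR_OPENAI_KEY", moderate_content=True)

result = generator.predict(
    query="What does the moderate_content flag do?",
    documents=[Document(content="It screens prompts and answers with the OpenAI Moderation API.")],
)
# result["answers"] is an empty list whenever the prompt or the generated answers are flagged.
```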
@@ -36,20 +36,18 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
        :param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
        :param kwargs: Additional keyword arguments passed to the underlying model.
        [See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).
        Note: additional model argument moderate_content will filter input and generated answers for potentially
        sensitive content using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation)
        if set. If the input or answers are flagged, an empty list is returned in place of the answers.
        """
        super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)

    def invoke(self, *args, **kwargs):
    def _execute_openai_request(
        self, prompt: Union[str, List[Dict]], base_payload: Dict, kwargs_with_defaults: Dict, stream: bool
    ):
        """
        It takes in either a prompt or a list of messages and returns a list of responses, using a REST invocation.

        :return: A list of generated responses.

        Note: Only kwargs relevant to OpenAI are passed to OpenAI rest API. Others kwargs are ignored.
        For more details, see [OpenAI ChatGPT API reference](https://platform.openai.com/docs/api-reference/chat).
        """
        prompt = kwargs.get("prompt", None)

        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
@@ -60,34 +58,8 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
            )

        kwargs_with_defaults = self.model_input_kwargs
        if kwargs:
            # we use keyword stop_words but OpenAI uses stop
            if "stop_words" in kwargs:
                kwargs["stop"] = kwargs.pop("stop_words")
            if "top_k" in kwargs:
                top_k = kwargs.pop("top_k")
                kwargs["n"] = top_k
                kwargs["best_of"] = top_k
            kwargs_with_defaults.update(kwargs)

        stream = (
            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
        )
        payload = {
            "model": self.model_name_or_path,
            "messages": messages,
            "max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
            "temperature": kwargs_with_defaults.get("temperature", 0.7),
            "top_p": kwargs_with_defaults.get("top_p", 1),
            "n": kwargs_with_defaults.get("n", 1),
            "stream": stream,
            "stop": kwargs_with_defaults.get("stop", None),
            "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
            "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
            "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
        }
        extra_payload = {"messages": messages}
        payload = {**base_payload, **extra_payload}
        if not stream:
            response = openai_request(url=self.url, headers=self.headers, payload=payload)
            _check_openai_finish_reason(result=response, payload=payload)
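As the docstrings above note, the chat layer accepts either a plain string or ChatML-style messages; a small sketch of the two prompt shapes it handles (the content strings are placeholders):

```python
# A plain string prompt is wrapped into a single user message by the layer:
prompt = "Summarize the moderation feature in one sentence."
wrapped = [{"role": "user", "content": prompt}]

# Alternatively, a ChatML-style list of message dicts can be passed as the prompt directly:
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the moderation feature in one sentence."},
]
```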
@@ -10,6 +10,7 @@ from haystack.utils.openai_utils import (
    _openai_text_completion_tokenization_details,
    load_openai_tokenizer,
    _check_openai_finish_reason,
    check_openai_policy_violation,
)
from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
@@ -47,6 +48,9 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
        kwargs includes: suffix, temperature, top_p, presence_penalty, frequency_penalty, best_of, n, max_tokens,
        logit_bias, stop, echo, and logprobs. For more details about these kwargs, see OpenAI
        [documentation](https://platform.openai.com/docs/api-reference/completions/create).
        Note: additional model argument moderate_content will filter input and generated answers for potentially
        sensitive content using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation)
        if set. If the input or answers are flagged, an empty list is returned in place of the answers.
        """
        super().__init__(model_name_or_path)
        if not isinstance(api_key, str) or len(api_key) == 0:
@@ -81,6 +85,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                "logit_bias",
                "stream",
                "stream_handler",
                "moderate_content",
            ]
            if key in kwargs
        }
@@ -101,7 +106,8 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):

    def invoke(self, *args, **kwargs):
        """
        Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation.
        Invokes a prompt on the model. Based on the model, it takes in a prompt (or either a prompt or a list of messages)
        and returns a list of responses using a REST invocation.

        :return: The responses are being returned.

@@ -109,12 +115,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
        For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
        """
        prompt = kwargs.get("prompt")
        if not prompt:
            raise ValueError(
                f"No prompt provided. Model {self.model_name_or_path} requires prompt."
                f"Make sure to provide prompt in kwargs."
            )

        # either stream is True (will use default handler) or stream_handler is provided
        kwargs_with_defaults = self.model_input_kwargs
        if kwargs:
            # we use keyword stop_words but OpenAI uses stop
@@ -125,28 +126,47 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                kwargs["n"] = top_k
                kwargs["best_of"] = top_k
            kwargs_with_defaults.update(kwargs)

        # either stream is True (will use default handler) or stream_handler is provided
        stream = (
            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
        )
        payload = {
        moderation = kwargs_with_defaults.get("moderate_content", False)
        base_payload = {  # payload common to all OpenAI models
            "model": self.model_name_or_path,
            "prompt": prompt,
            "suffix": kwargs_with_defaults.get("suffix", None),
            "max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
            "temperature": kwargs_with_defaults.get("temperature", 0.7),
            "top_p": kwargs_with_defaults.get("top_p", 1),
            "n": kwargs_with_defaults.get("n", 1),
            "stream": stream,
            "logprobs": kwargs_with_defaults.get("logprobs", None),
            "echo": kwargs_with_defaults.get("echo", False),
            "stop": kwargs_with_defaults.get("stop", None),
            "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
            "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
            "best_of": kwargs_with_defaults.get("best_of", 1),
            "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
        }
        if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
            return []
        responses = self._execute_openai_request(
            prompt=prompt, base_payload=base_payload, kwargs_with_defaults=kwargs_with_defaults, stream=stream
        )
        if moderation and check_openai_policy_violation(input=responses, headers=self.headers):
            logger.info("Response '%s' will not be returned due to potential policy violation.", responses)
            return []
        return responses

    def _execute_openai_request(self, prompt: str, base_payload: Dict, kwargs_with_defaults: Dict, stream: bool):
        if not prompt:
            raise ValueError(
                f"No prompt provided. Model {self.model_name_or_path} requires prompt."
                f"Make sure to provide prompt in kwargs."
            )
        extra_payload = {
            "prompt": prompt,
            "suffix": kwargs_with_defaults.get("suffix", None),
            "logprobs": kwargs_with_defaults.get("logprobs", None),
            "echo": kwargs_with_defaults.get("echo", False),
            "best_of": kwargs_with_defaults.get("best_of", 1),
        }
        payload = {**base_payload, **extra_payload}
        if not stream:
            res = openai_request(url=self.url, headers=self.headers, payload=payload)
            _check_openai_finish_reason(result=res, payload=payload)
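For illustration, a sketch of exercising the same guard directly on the invocation layer. This assumes the constructor accepts api_key and model_name_or_path as keyword arguments (as suggested by the surrounding code) and that moderate_content is passed as a model kwarg; the key is a placeholder:

```python
from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer

layer = OpenAIInvocationLayer(
    api_key="YOUR_OPENAI_KEY",
    model_name_or_path="text-davinci-003",
    moderate_content=True,  # consumed locally by the layer, not forwarded in the completion payload
)

# invoke() returns [] when either the prompt or the generated completions are flagged.
responses = layer.invoke(prompt="Write one sentence about content moderation.")
print(responses)
```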
@@ -18,7 +18,6 @@ from haystack.environment import (

logger = logging.getLogger(__name__)


machine = platform.machine().lower()
system = platform.system()

@@ -26,6 +25,8 @@ OPENAI_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
OPENAI_BACKOFF = int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
OPENAI_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))

OPENAI_MODERATION_URL = "https://api.openai.com/v1/moderations"


def load_openai_tokenizer(tokenizer_name: str):
    """Load either the tokenizer from tiktoken (if the library is available) or fallback to the GPT2TokenizerFast
@@ -150,6 +151,26 @@ def openai_request(
    return response


def check_openai_policy_violation(input: Union[List[str], str], headers: Dict) -> bool:
    """
    Calls the moderation endpoint to check if the text(s) violate the policy.
    See [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation) for more details.
    Returns true if any of the input is flagged as any of ['sexual', 'hate', 'violence', 'self-harm', 'sexual/minors', 'hate/threatening', 'violence/graphic'].
    """
    response = openai_request(url=OPENAI_MODERATION_URL, headers=headers, payload={"input": input})
    results = response["results"]
    flagged = any(res["flagged"] for res in results)
    if flagged:
        for result in results:
            if result["flagged"]:
                logger.debug(
                    "OpenAI Moderation API flagged the text '%s' as a potential policy violation of the following categories: %s",
                    input,
                    result["categories"],
                )
    return flagged


def _check_openai_finish_reason(result: Dict, payload: Dict) -> None:
    """Check the `finish_reason` the answers returned by OpenAI completions endpoint.
    If the `finish_reason` is `length` or `content_filter`, log a warning to the user.
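The helper can also be called on its own; a sketch assuming the usual bearer-token headers used by the other OpenAI calls in this module (the key and input text are placeholders):

```python
from haystack.utils.openai_utils import check_openai_policy_violation

headers = {"Content-Type": "application/json", "Authorization": "Bearer YOUR_OPENAI_KEY"}

# Accepts a single string or a list of strings; returns True if any text is flagged.
flagged = check_openai_policy_violation(input=["some text to screen"], headers=headers)
print(flagged)
```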
@@ -1013,6 +1013,40 @@ def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model):


@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer", lambda tokenizer_name: None)
@patch("haystack.nodes.prompt.prompt_model.PromptModel._ensure_token_limit", lambda self, prompt: prompt)
def test_content_moderation_gpt_3_and_gpt_3_5():
    """
    Check all possible cases of the moderation checks passing / failing in a PromptNode call
    for both ChatGPTInvocationLayer and OpenAIInvocationLayer.
    """
    prompt_node_gpt_3_5 = PromptNode(
        model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True}
    )
    prompt_node_gpt_3 = PromptNode(
        model_name_or_path="text-davinci-003", api_key="key", model_kwargs={"moderate_content": True}
    )
    with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
        "haystack.nodes.prompt.invocation_layer.chatgpt.ChatGPTInvocationLayer._execute_openai_request"
    ) as mock_execute_gpt_3_5, patch(
        "haystack.nodes.prompt.invocation_layer.open_ai.OpenAIInvocationLayer._execute_openai_request"
    ) as mock_execute_gpt_3:
        VIOLENT_TEXT = "some violent text"
        mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
        # case 1: prompt fails the moderation check
        # prompt should not be sent to OpenAi & function should return an empty list
        mock_check.return_value = True
        assert prompt_node_gpt_3_5(VIOLENT_TEXT) == prompt_node_gpt_3(VIOLENT_TEXT) == []
        # case 2: prompt passes the moderation check but the generated output fails the check
        # function should also return an empty list
        mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = [VIOLENT_TEXT]
        assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == []
        # case 3: both prompt and output pass the moderation check
        # function should return the output
        mock_execute_gpt_3_5.return_value = mock_execute_gpt_3.return_value = ["normal output"]
        assert prompt_node_gpt_3_5("normal prompt") == prompt_node_gpt_3("normal prompt") == ["normal output"]


@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_node_warns_about_missing_documents(mock_model, caplog):
    lfqa_prompt = PromptTemplate(
@@ -1,3 +1,5 @@
import copy

import pytest
from unittest.mock import patch

@@ -5,7 +7,11 @@ import pytest
from tenacity import wait_none

from haystack.errors import OpenAIError, OpenAIRateLimitError, OpenAIUnauthorizedError
from haystack.utils.openai_utils import openai_request, _openai_text_completion_tokenization_details
from haystack.utils.openai_utils import (
    openai_request,
    _openai_text_completion_tokenization_details,
    check_openai_policy_violation,
)


@pytest.mark.unit
@@ -98,3 +104,46 @@ def test_openai_request_does_not_retry_on_success(mock_requests):
    openai_request.retry_with(wait=wait_none())(url="some_url", headers={}, payload={}, read_response=False)

    assert mock_requests.request.call_count == 1


@pytest.mark.unit
def test_check_openai_policy_violation():
    moderation_endpoint_mock_response_flagged = {
        "id": "modr-7Ok9zndoeSn5ij654vuNCgFVomU4U",
        "model": "text-moderation-004",
        "results": [
            {
                "flagged": True,
                "categories": {
                    "sexual": False,
                    "hate": False,
                    "violence": True,
                    "self-harm": True,
                    "sexual/minors": False,
                    "hate/threatening": False,
                    "violence/graphic": False,
                },
                "category_scores": {
                    "sexual": 2.6659495e-06,
                    "hate": 1.9359974e-05,
                    "violence": 0.95964026,
                    "self-harm": 0.9696306,
                    "sexual/minors": 4.1061935e-07,
                    "hate/threatening": 4.9856953e-07,
                    "violence/graphic": 0.2683866,
                },
            }
        ],
    }
    moderation_endpoint_mock_response_not_flagged = copy.deepcopy(moderation_endpoint_mock_response_flagged)
    moderation_endpoint_mock_response_not_flagged["results"][0]["flagged"] = False
    moderation_endpoint_mock_response_not_flagged["results"][0]["categories"].update(
        {"violence": False, "self-harm": False}
    )
    with patch("haystack.utils.openai_utils.openai_request") as mock_openai_request:
        # check that the function returns True if the input is flagged
        mock_openai_request.return_value = moderation_endpoint_mock_response_flagged
        assert check_openai_policy_violation(input="violent input", headers={}) == True
        # check that the function returns False if the input is not flagged
        mock_openai_request.return_value = moderation_endpoint_mock_response_not_flagged
        assert check_openai_policy_violation(input="ok input", headers={}) == False