Fix HF stop words (single stop word) (#4584)

This commit is contained in:
Vladimir Blagojevic 2023-04-04 14:45:10 +02:00 committed by GitHub
parent ce61eda970
commit a8d283cfac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 4 deletions

View File

@ -125,12 +125,17 @@ class StopWordsCriteria(StoppingCriteria):
Stops text generation if any one of the stop words is generated.
"""
def __init__(self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], stop_words: List[str]):
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stop_words: List[str],
device: Union[str, torch.device] = "cpu",
):
super().__init__()
self.stop_words = tokenizer.encode(stop_words, add_special_tokens=False, return_tensors="pt")
self.stop_words = tokenizer(stop_words, add_special_tokens=False, return_tensors="pt").to(device)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(torch.isin(input_ids[-1], self.stop_words[-1]))
return any(torch.isin(input_ids[-1], self.stop_words["input_ids"]))
class HFLocalInvocationLayer(PromptModelInvocationLayer):
@ -268,7 +273,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
model_input_kwargs["return_full_text"] = False
model_input_kwargs["max_new_tokens"] = self.max_length
if stop_words:
sw = StopWordsCriteria(tokenizer=self.pipe.tokenizer, stop_words=stop_words)
sw = StopWordsCriteria(tokenizer=self.pipe.tokenizer, stop_words=stop_words, device=self.pipe.device)
model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw])
if top_k:
model_input_kwargs["num_return_sequences"] = top_k

View File

@ -237,6 +237,14 @@ def test_open_ai_warn_if_max_tokens_is_too_short(prompt_model, caplog):
def test_stop_words(prompt_model):
skip_test_for_invalid_key(prompt_model)
# test single stop word for both HF and OpenAI
# set stop words in PromptNode
node = PromptNode(prompt_model, stop_words=["capital"])
# with default prompt template and stop words set in PN
r = node.prompt("question-generation", documents=["Berlin is the capital of Germany."])
assert r[0] == "What is the" or r[0] == "What city is the"
# test stop words for both HF and OpenAI
# set stop words in PromptNode
node = PromptNode(prompt_model, stop_words=["capital", "Germany"])