fix: StopWordsCriteria doesn't compare the stop word token ids with the input ids in a continuous and sequential order (#5503)

* bug fix

* add release note

* add unit test

* refactor

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
This commit is contained in:
Fanli Lin 2023-08-08 14:35:10 +08:00 committed by GitHub
parent 99cb95a63a
commit f6b50cfdf9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 2 deletions

View File

@ -51,8 +51,18 @@ with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_a
self.stop_words = encoded_stop_words.input_ids.to(device)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    """
    Stopping criterion: return True as soon as any configured stop word appears
    at the end of the generated sequence.

    :param input_ids: Token ids generated so far; the last row is the sequence checked.
    :param scores: Model scores for the current step (unused, required by the
        HF StoppingCriteria interface).
    :return: True if any stop-word token sequence terminates the generated ids.
    """
    # NOTE: the old `torch.isin`-based check was removed because it only tested
    # token membership, not that the stop-word tokens occur contiguously and in order.
    for stop_word in self.stop_words:
        if self.is_stop_word_found(input_ids, stop_word):
            return True
    return False
def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_word: torch.Tensor) -> bool:
    """
    Check whether *stop_word* occurs as the trailing, contiguous token sequence
    of the last generated sequence.

    :param generated_text_ids: 2-D tensor of generated token ids; only the last
        row (most recent sequence) is inspected.
    :param stop_word: 1-D tensor holding the token ids of a single stop word.
    :return: True if the tail of the generated sequence equals *stop_word*.
    """
    last_sequence = generated_text_ids[-1]
    num_generated = last_sequence.size(0)
    num_stop_tokens = stop_word.size(0)
    # Guard: a stop word longer than the generated text cannot be present.
    # Without this, the tail slice is shorter than the stop word and torch
    # broadcasting (e.g. shape (1,) vs (k,)) could yield a false positive.
    if num_stop_tokens > num_generated:
        return False
    return bool(last_sequence[num_generated - num_stop_tokens:].eq(stop_word).all())
class HFLocalInvocationLayer(PromptModelInvocationLayer):

View File

@ -0,0 +1,4 @@
---
fixes:
- |
Fix `StopWordsCriteria` so that stop-word token ids are matched only when they appear as a continuous, in-order sequence in the generated tokens.

View File

@ -468,6 +468,44 @@ def test_stop_words_multiple_token(stop_words: List[str]):
assert "health" not in result[0]
@pytest.mark.unit
def test_stop_words_criteria():
    """
    StopWordsCriteria must only fire when the stop-word token ids appear as a
    contiguous, in-order subsequence of the generated ids — mere presence of
    the individual tokens is not enough.
    """
    # token ids for the stop word "unambiguously"
    stop_word_tokens = torch.tensor([[73, 24621, 11937]])
    # token ids for "This is ambiguously, but is unrelated." — contains all the
    # stop-word tokens ("un" from "unrelated", "ambiguously" on its own) but
    # scattered, never contiguous
    scattered_ids = torch.tensor([[100, 19, 24621, 11937, 6, 68, 19, 73, 3897, 5]])
    # token ids for "This is unambiguously" — stop-word tokens form the tail
    contiguous_ids = torch.tensor([[100, 19, 73, 24621, 11937]])

    # The previous implementation, roughly `all(torch.isin(stop_word_tokens, scattered_ids)[0])`,
    # tested only membership and therefore reported "unambiguously" as present in
    # BOTH sentences, which is wrong for the scattered case.
    criteria = StopWordsCriteria(tokenizer=Mock(), stop_words=["mock data"])
    # the tokenizer is mocked, so inject the real stop-word ids directly
    criteria.stop_words = stop_word_tokens

    # scattered tokens: present individually, but not contiguous -> no stop
    assert not criteria(scattered_ids, scores=None)
    # contiguous tail: present and in order -> stop
    assert criteria(contiguous_ids, scores=None)
@pytest.mark.integration
@pytest.mark.parametrize("stop_words", [["Berlin"], ["Berlin", "Brandenburg"], ["Berlin", "Brandenburg", "Germany"]])
def test_stop_words_not_being_found(stop_words: List[str]):