mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 12:37:27 +00:00
fix: StopWordsCriteria doesn't compare the stop word token ids with the input ids in a continuous and sequential order (#5503)
* bug fix * add release note * add unit test * refactor --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
This commit is contained in:
parent
99cb95a63a
commit
f6b50cfdf9
@ -51,8 +51,18 @@ with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_a
|
||||
self.stop_words = encoded_stop_words.input_ids.to(device)
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    """
    Return True when any configured stop-word token sequence appears at the
    end of the generated ids, signalling generation should stop.

    :param input_ids: Generated token ids so far; only the last row is checked.
    :param scores: Model scores (unused; required by the HF StoppingCriteria API).
    :return: True if a stop word was found, False otherwise.
    """
    # NOTE: a plain membership test such as `torch.isin(self.stop_words, input_ids[-1])`
    # is wrong here — it ignores token order and adjacency, so stop-word tokens
    # scattered across the text (e.g. "un" from "unrelated" plus "ambiguously")
    # would falsely match "unambiguously". Each stop-word sequence must instead
    # be compared against the tail of the generated ids, contiguously and in order.
    for stop_word in self.stop_words:
        if self.is_stop_word_found(input_ids, stop_word):
            return True
    return False
|
||||
def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_word: torch.Tensor) -> bool:
    """
    Check whether `stop_word` occurs as the contiguous, in-order suffix of the
    most recent generated sequence.

    :param generated_text_ids: Batch of generated token ids; the last row is inspected.
    :param stop_word: Token-id sequence of a single stop word.
    :return: True if the tail of the generated ids equals the stop word token ids.
    """
    latest_ids = generated_text_ids[-1]
    # Compare only the trailing slice whose length matches the stop word.
    tail_start = latest_ids.size(0) - stop_word.size(0)
    return all(latest_ids[tail_start:].eq(stop_word))
|
||||
|
||||
|
||||
class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
fixes:
  - |
    Fix StopWordsCriteria not checking stop word tokens in a continuous and sequential order
|
||||
@ -468,6 +468,44 @@ def test_stop_words_multiple_token(stop_words: List[str]):
|
||||
assert "health" not in result[0]
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_stop_words_criteria():
    """
    StopWordsCriteria must report a stop word only when its token ids occur
    contiguously and in order in the generated ids.
    """
    # Token ids for the stop word "unambiguously".
    stop_words_id = torch.tensor([[73, 24621, 11937]])

    # "This is ambiguously, but is unrelated." — every stop-word token is
    # present, but scattered: "un" comes from "unrelated" and "ambiguously"
    # from "ambiguously". The old implementation,
    # `all(torch.isin(stop_words_id, input_ids1)[0])`, only tested membership,
    # not contiguous order, and therefore wrongly returned True for this input.
    input_ids1 = torch.tensor([[100, 19, 24621, 11937, 6, 68, 19, 73, 3897, 5]])
    # "This is unambiguously" — the stop word appears as a contiguous run.
    input_ids2 = torch.tensor([[100, 19, 73, 24621, 11937]])

    criteria = StopWordsCriteria(tokenizer=Mock(), stop_words=["mock data"])
    # The tokenizer is mocked, so install the stop-word ids by hand.
    criteria.stop_words = stop_words_id

    # Tokens present but NOT contiguous -> must not stop.
    assert not criteria(input_ids1, scores=None)
    # Tokens present AND contiguous -> must stop.
    assert criteria(input_ids2, scores=None)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("stop_words", [["Berlin"], ["Berlin", "Brandenburg"], ["Berlin", "Brandenburg", "Germany"]])
|
||||
def test_stop_words_not_being_found(stop_words: List[str]):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user