diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py
index aaeebc889..57ef37ef6 100644
--- a/haystack/nodes/prompt/invocation_layer/hugging_face.py
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -51,8 +51,18 @@ with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_a
         self.stop_words = encoded_stop_words.input_ids.to(device)
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_result = torch.isin(self.stop_words, input_ids[-1])
-        return any(all(stop_word) for stop_word in stop_result)
+        for stop_word in self.stop_words:
+            found_stop_word = self.is_stop_word_found(input_ids, stop_word)
+            if found_stop_word:
+                return True
+        return False
+
+    def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_word: torch.Tensor) -> bool:
+        generated_text_ids = generated_text_ids[-1]
+        len_generated_text_ids = generated_text_ids.size(0)
+        len_stop_word = stop_word.size(0)
+        result = all(generated_text_ids[len_generated_text_ids - len_stop_word :].eq(stop_word))
+        return result
 
 
 class HFLocalInvocationLayer(PromptModelInvocationLayer):
diff --git a/releasenotes/notes/fix-stop-words-criteria-check-order-bug-4badfcc021dfc92a.yaml b/releasenotes/notes/fix-stop-words-criteria-check-order-bug-4badfcc021dfc92a.yaml
new file mode 100644
index 000000000..6f06db3c9
--- /dev/null
+++ b/releasenotes/notes/fix-stop-words-criteria-check-order-bug-4badfcc021dfc92a.yaml
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Fix StopWordsCriteria not checking stop word tokens in a continuous and sequential order
diff --git a/test/prompt/invocation_layer/test_hugging_face.py b/test/prompt/invocation_layer/test_hugging_face.py
index b017478e3..9aeef2508 100644
--- a/test/prompt/invocation_layer/test_hugging_face.py
+++ b/test/prompt/invocation_layer/test_hugging_face.py
@@ -468,6 +468,44 @@ def test_stop_words_multiple_token(stop_words: List[str]):
     assert "health" not in result[0]
 
 
+@pytest.mark.unit
+def test_stop_words_criteria():
+    """
+    Test that StopWordsCriteria will check stop word tokens in a continuous and sequential order
+    """
+    # input ids for "unambiguously"
+    stop_words_id = torch.tensor([[73, 24621, 11937]])
+
+    # input ids for "This is ambiguously, but is unrelated."
+    input_ids1 = torch.tensor([[100, 19, 24621, 11937, 6, 68, 19, 73, 3897, 5]])
+    # input ids for "This is unambiguously"
+    input_ids2 = torch.tensor([[100, 19, 73, 24621, 11937]])
+
+    # We used to implement stop words algorithm using the torch.isin function like this:
+    # `all(torch.isin(stop_words_id, input_ids1)[0])`
+    # However, this algorithm is not correct as it will return True for presence of "unambiguously" in input_ids1
+    # and True for presence of "unambiguously" in input_ids2. This is because the algorithm will check
+    # if the stop word tokens are present in the input_ids, but it does not check if the stop word tokens are
+    # present in a continuous/sequential order.
+
+    # In "This is ambiguously, but is unrelated." sentence the "un" token comes from "unrelated" and the
+    # "ambiguously" token comes from "ambiguously". The algorithm will return True for presence of
+    # "unambiguously" in input_ids1 which is not correct.
+
+    stop_words_criteria = StopWordsCriteria(tokenizer=Mock(), stop_words=["mock data"])
+    # because we are mocking the tokenizer, we need to set the stop words manually
+    stop_words_criteria.stop_words = stop_words_id
+
+    # this is the correct algorithm to check if the stop word tokens are present in a continuous and sequential order
+    # For the input_ids1, the stop word tokens are present BUT not in a continuous order
+    present_and_continuous = stop_words_criteria(input_ids1, scores=None)
+    assert not present_and_continuous
+
+    # For the input_ids2, the stop word tokens are both present and in a continuous order
+    present_and_continuous = stop_words_criteria(input_ids2, scores=None)
+    assert present_and_continuous
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("stop_words", [["Berlin"], ["Berlin", "Brandenburg"], ["Berlin", "Brandenburg", "Germany"]])
 def test_stop_words_not_being_found(stop_words: List[str]):