From c5342d1110554258ca81b75c4e775e9e01ea785e Mon Sep 17 00:00:00 2001
From: Bijay Gurung <4636315+bglearning@users.noreply.github.com>
Date: Wed, 6 Dec 2023 09:49:02 +0100
Subject: [PATCH] fix: Prevent invalid answer from being selected in
 ExtractiveReader (#6460)

* Fix invalid answer being selected issue on ExtractiveReader

* Rename variables to not shadow arguments
---
 haystack/components/readers/extractive.py    | 66 ++++++++---------
 ...e-reader-invalid-ans-a88e6b1d1ee897aa.yaml |  4 ++
 test/components/readers/test_extractive.py   | 72 -------------------
 3 files changed, 32 insertions(+), 110 deletions(-)
 create mode 100644 releasenotes/notes/extractive-reader-invalid-ans-a88e6b1d1ee897aa.yaml

diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py
index 1fd5e38d0..86291c415 100644
--- a/haystack/components/readers/extractive.py
+++ b/haystack/components/readers/extractive.py
@@ -206,62 +206,52 @@ class ExtractiveReader:
         implementations, it doesn't normalize the scores to make them easier to compare across different splits.
         Returns the top k answer spans.
         """
-        mask = sequence_ids == 1
-        mask = torch.logical_and(mask, attention_mask == 1)
-        start = torch.where(mask, start, -torch.inf)
-        end = torch.where(mask, end, -torch.inf)
+        mask = sequence_ids == 1  # Only keep tokens from the context (should ignore special tokens)
+        mask = torch.logical_and(mask, attention_mask == 1)  # Also drop padding (attention_mask is 0 there)
+        start = torch.where(mask, start, -torch.inf)  # Apply the mask to the start logits
+        end = torch.where(mask, end, -torch.inf)  # Apply the mask to the end logits
         start = start.unsqueeze(-1)
         end = end.unsqueeze(-2)
 
         logits = start + end  # shape: (batch_size, seq_length (start), seq_length (end))
+
+        # From here onwards the mask is the same for all instances in the batch,
+        # so we do away with the batch dimension
         mask = torch.ones(logits.shape[-2:], dtype=torch.bool, device=self.device)
         mask = torch.triu(mask)  # End shouldn't be before start
         masked_logits = torch.where(mask, logits, -torch.inf)
         probabilities = torch.sigmoid(masked_logits * self.calibration_factor)
 
         flat_probabilities = probabilities.flatten(-2, -1)  # necessary for topk
+
+        # topk can also return invalid candidates if answers_per_seq > num_valid_candidates;
+        # we only keep candidates with probability > 0 later on
         candidates = torch.topk(flat_probabilities, answers_per_seq)
         seq_length = logits.shape[-1]
         start_candidates = candidates.indices // seq_length  # Recover indices from flattening
         end_candidates = candidates.indices % seq_length
+        candidates_values = candidates.values.cpu()
         start_candidates = start_candidates.cpu()
         end_candidates = end_candidates.cpu()
 
-        start_candidates_tokens_to_chars = [
-            [encoding.token_to_chars(start) for start in candidates]
-            for candidates, encoding in zip(start_candidates, encodings)
-        ]
-        if missing_start_tokens := [
-            (batch, index)
-            for batch, token_to_chars in enumerate(start_candidates_tokens_to_chars)
-            for index, pair in enumerate(token_to_chars)
-            if pair is None
-        ]:
-            logger.warning("Some start tokens could not be found in the context: %s", missing_start_tokens)
-        start_candidates_char_indices = [
-            [token_to_chars[0] if token_to_chars else None for token_to_chars in candidates]
-            for candidates in start_candidates_tokens_to_chars
-        ]
+        start_candidates_tokens_to_chars = []
+        end_candidates_tokens_to_chars = []
+        for i, (s_candidates, e_candidates, encoding) in enumerate(zip(start_candidates, end_candidates, encodings)):
+            # Those with probabilities > 0 are valid
+            valid = candidates_values[i] > 0
+            s_char_spans = []
+            e_char_spans = []
+            for start_token, end_token in zip(s_candidates[valid], e_candidates[valid]):
+                # token_to_chars returns `None` for special tokens,
+                # but we shouldn't have special tokens in the answers at this point.
+                # The whole span is given by the start of the start_token (index 0)
+                # and the end of the end_token (index 1)
+                s_char_spans.append(encoding.token_to_chars(start_token)[0])
+                e_char_spans.append(encoding.token_to_chars(end_token)[1])
+            start_candidates_tokens_to_chars.append(s_char_spans)
+            end_candidates_tokens_to_chars.append(e_char_spans)
 
-        end_candidates_tokens_to_chars = [
-            [encoding.token_to_chars(end) for end in candidates]
-            for candidates, encoding in zip(end_candidates, encodings)
-        ]
-        if missing_end_tokens := [
-            (batch, index)
-            for batch, token_to_chars in enumerate(end_candidates_tokens_to_chars)
-            for index, pair in enumerate(token_to_chars)
-            if pair is None
-        ]:
-            logger.warning("Some end tokens could not be found in the context: %s", missing_end_tokens)
-        end_candidates_char_indices = [
-            [token_to_chars[1] if token_to_chars else None for token_to_chars in candidates]
-            for candidates in end_candidates_tokens_to_chars
-        ]
-
-        probabilities = candidates.values.cpu()
-
-        return start_candidates_char_indices, end_candidates_char_indices, probabilities
+        return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, candidates_values
 
     def _nest_answers(
         self,
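Note: outside the diff, the span-selection logic the hunk above touches boils down to the following self-contained sketch. The toy logits, k=12, and an implicit calibration_factor of 1.0 are invented for illustration; this is not the reader's actual code path, only the mechanism it relies on.

import torch

# Toy start/end logits for one sequence of length 4 (values invented for illustration)
start = torch.tensor([0.2, 1.5, 0.1, 0.9])
end = torch.tensor([0.3, 0.8, 1.2, 0.4])

# Pairwise span scores: logits[i, j] = start[i] + end[j]
logits = start.unsqueeze(-1) + end.unsqueeze(-2)

# Upper-triangular mask: a span's end must not come before its start
mask = torch.triu(torch.ones(logits.shape, dtype=torch.bool))
masked_logits = torch.where(mask, logits, -torch.inf)

# sigmoid(-inf) == 0, so masked spans end up with probability exactly 0
probabilities = torch.sigmoid(masked_logits)

# Only 10 of the 16 spans are valid here, so topk with k=12 surfaces two
# masked spans, which is exactly the bug this patch fixes
top = torch.topk(probabilities.flatten(), k=12)
seq_length = logits.shape[-1]
start_candidates = top.indices // seq_length  # recover row (start) from flat index
end_candidates = top.indices % seq_length  # recover column (end) from flat index

# The fix: keep only candidates with probability > 0, i.e. drop masked spans
valid = top.values > 0
print(valid.sum().item(), "of", len(valid), "candidates kept")
print(start_candidates[valid], end_candidates[valid])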
diff --git a/releasenotes/notes/extractive-reader-invalid-ans-a88e6b1d1ee897aa.yaml b/releasenotes/notes/extractive-reader-invalid-ans-a88e6b1d1ee897aa.yaml
new file mode 100644
index 000000000..edec36fc4
--- /dev/null
+++ b/releasenotes/notes/extractive-reader-invalid-ans-a88e6b1d1ee897aa.yaml
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Fix `ExtractiveReader` returning invalid answers when `answers_per_seq` is greater than the number of valid answer candidates.
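The hunk below removes test_missing_token_to_chars_values, which pinned the old warning path where top-k surfaced masked spans and token_to_chars then returned None for them. As a quick illustration of the token_to_chars behavior the new code relies on (the checkpoint and inputs are only examples; a fast tokenizer is required):

from transformers import AutoTokenizer

# Example checkpoint; any fast tokenizer exposes the same Encoding API
tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
encoding = tokenizer("What is Haystack?", "Haystack is a framework.").encodings[0]

# token_to_chars maps a token index back to a (start, end) character span in
# the input; it returns None for special tokens such as <s> and </s>, which is
# what the removed warning path had to guard against
for token_index, token in enumerate(encoding.tokens):
    print(token_index, token, encoding.token_to_chars(token_index))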
diff --git a/test/components/readers/test_extractive.py b/test/components/readers/test_extractive.py
index a61cbe36b..a48c984c4 100644
--- a/test/components/readers/test_extractive.py
+++ b/test/components/readers/test_extractive.py
@@ -261,78 +261,6 @@ def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
     mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")
 
 
-def test_missing_token_to_chars_values():
-    # See https://github.com/deepset-ai/haystack/issues/6098
-
-    def mock_tokenize(
-        texts: List[str],
-        text_pairs: List[str],
-        padding: bool,
-        truncation: bool,
-        max_length: int,
-        return_tensors: str,
-        return_overflowing_tokens: bool,
-        stride: int,
-    ):
-        assert padding
-        assert truncation
-        assert return_tensors == "pt"
-        assert return_overflowing_tokens
-
-        tokens = Mock()
-
-        num_splits = [ceil(len(text + pair) / max_length) for text, pair in zip(texts, text_pairs)]
-        tokens.overflow_to_sample_mapping = [i for i, num in enumerate(num_splits) for _ in range(num)]
-        num_samples = sum(num_splits)
-        tokens.encodings = [Mock() for _ in range(num_samples)]
-        sequence_ids = [0] * 16 + [1] * 16 + [None] * (max_length - 32)
-        for encoding in tokens.encodings:
-            encoding.sequence_ids = sequence_ids
-            encoding.token_to_chars = lambda i: None
-        tokens.input_ids = torch.zeros(num_samples, max_length, dtype=torch.int)
-        attention_mask = torch.zeros(num_samples, max_length, dtype=torch.int)
-        attention_mask[:32] = 1
-        tokens.attention_mask = attention_mask
-        return tokens
-
-    class MockModel(torch.nn.Module):
-        def to(self, device):
-            assert device == "cpu:0"
-            self.device_set = True
-            return self
-
-        def forward(self, input_ids, attention_mask, *args, **kwargs):
-            assert input_ids.device == torch.device("cpu")
-            assert attention_mask.device == torch.device("cpu")
-            assert self.device_set
-            start = torch.zeros(input_ids.shape[:2])
-            end = torch.zeros(input_ids.shape[:2])
-            start[:, 27] = 1
-            end[:, 31] = 1
-            end[:, 32] = 1
-            prediction = Mock()
-            prediction.start_logits = start
-            prediction.end_logits = end
-            return prediction
-
-    with patch("haystack.components.readers.extractive.AutoTokenizer.from_pretrained") as tokenizer, patch(
-        "haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained"
-    ) as model:
-        tokenizer.return_value = mock_tokenize
-        model.return_value = MockModel()
-        reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0")
-        reader.warm_up()
-
-        answers = reader.run(example_queries[0], example_documents[0], top_k=3)[
-            "answers"
-        ]  # [0] Uncomment and remove first two indices when batching support is reintroduced
-        for doc, answer in zip(example_documents[0], answers[:3]):
-            assert answer.start is None
-            assert answer.end is None
-            assert doc.content is not None
-            assert answer.data == doc.content
-
-
 @pytest.mark.integration
 def test_t5():
     reader = ExtractiveReader("TARUNBHATT/flan-t5-small-finetuned-squad")
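For reference, a minimal end-to-end run of the patched component, assuming the 2.x-beta API that the deleted test above uses (model_name_or_path, warm_up, run, and the "answers" result key); the checkpoint, query, and document content are placeholders:

from haystack import Document
from haystack.components.readers.extractive import ExtractiveReader

# Placeholder checkpoint; any extractive QA model works
reader = ExtractiveReader(model_name_or_path="deepset/roberta-base-squad2")
reader.warm_up()

documents = [Document(content="Haystack is an open source framework for building search systems.")]
result = reader.run(query="What is Haystack?", documents=documents, top_k=3)

# With this patch, candidates whose probability was masked to zero can no
# longer surface here, even when answers_per_seq exceeds the number of
# valid spans
for answer in result["answers"]:
    print(answer.data)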