fix: Prevent invalid answer from being selected in ExtractiveReader (#6460)

* Fix invalid answers being selected in ExtractiveReader

* Rename variables to not shadow arguments
Bijay Gurung 2023-12-06 09:49:02 +01:00 committed by GitHub
parent 05a30c24aa
commit c5342d1110
3 changed files with 32 additions and 110 deletions


@@ -206,62 +206,52 @@ class ExtractiveReader:
         implementations, it doesn't normalize the scores to make them easier to compare across different
         splits. Returns the top k answer spans.
         """
-        mask = sequence_ids == 1
-        mask = torch.logical_and(mask, attention_mask == 1)
-        start = torch.where(mask, start, -torch.inf)
-        end = torch.where(mask, end, -torch.inf)
+        mask = sequence_ids == 1  # Only keep tokens from the context (should ignore special tokens)
+        mask = torch.logical_and(mask, attention_mask == 1)  # Definitely remove special tokens
+        start = torch.where(mask, start, -torch.inf)  # Apply the mask on the start logits
+        end = torch.where(mask, end, -torch.inf)  # Apply the mask on the end logits
         start = start.unsqueeze(-1)
         end = end.unsqueeze(-2)
         logits = start + end  # shape: (batch_size, seq_length (start), seq_length (end))
         # The mask here onwards is the same for all instances in the batch
         # As such we do away with the batch dimension
         mask = torch.ones(logits.shape[-2:], dtype=torch.bool, device=self.device)
         mask = torch.triu(mask)  # End shouldn't be before start
         masked_logits = torch.where(mask, logits, -torch.inf)
         probabilities = torch.sigmoid(masked_logits * self.calibration_factor)
         flat_probabilities = probabilities.flatten(-2, -1)  # necessary for topk
+        # topk can return invalid candidates as well if answers_per_seq > num_valid_candidates
+        # We only keep probability > 0 candidates later on
         candidates = torch.topk(flat_probabilities, answers_per_seq)
         seq_length = logits.shape[-1]
         start_candidates = candidates.indices // seq_length  # Recover indices from flattening
         end_candidates = candidates.indices % seq_length
+        candidates_values = candidates.values.cpu()
         start_candidates = start_candidates.cpu()
         end_candidates = end_candidates.cpu()
-        start_candidates_tokens_to_chars = [
-            [encoding.token_to_chars(start) for start in candidates]
-            for candidates, encoding in zip(start_candidates, encodings)
-        ]
-        if missing_start_tokens := [
-            (batch, index)
-            for batch, token_to_chars in enumerate(start_candidates_tokens_to_chars)
-            for index, pair in enumerate(token_to_chars)
-            if pair is None
-        ]:
-            logger.warning("Some start tokens could not be found in the context: %s", missing_start_tokens)
-        start_candidates_char_indices = [
-            [token_to_chars[0] if token_to_chars else None for token_to_chars in candidates]
-            for candidates in start_candidates_tokens_to_chars
-        ]
+        start_candidates_tokens_to_chars = []
+        end_candidates_tokens_to_chars = []
+        for i, (s_candidates, e_candidates, encoding) in enumerate(zip(start_candidates, end_candidates, encodings)):
+            # Those with probabilities > 0 are valid
+            valid = candidates_values[i] > 0
+            s_char_spans = []
+            e_char_spans = []
+            for start_token, end_token in zip(s_candidates[valid], e_candidates[valid]):
+                # token_to_chars returns `None` for special tokens
+                # But we shouldn't have special tokens in the answers at this point
+                # The whole span is given by the start of the start_token (index 0)
+                # and the end of the end token (index 1)
+                s_char_spans.append(encoding.token_to_chars(start_token)[0])
+                e_char_spans.append(encoding.token_to_chars(end_token)[1])
+            start_candidates_tokens_to_chars.append(s_char_spans)
+            end_candidates_tokens_to_chars.append(e_char_spans)
-        end_candidates_tokens_to_chars = [
-            [encoding.token_to_chars(end) for end in candidates]
-            for candidates, encoding in zip(end_candidates, encodings)
-        ]
-        if missing_end_tokens := [
-            (batch, index)
-            for batch, token_to_chars in enumerate(end_candidates_tokens_to_chars)
-            for index, pair in enumerate(token_to_chars)
-            if pair is None
-        ]:
-            logger.warning("Some end tokens could not be found in the context: %s", missing_end_tokens)
-        end_candidates_char_indices = [
-            [token_to_chars[1] if token_to_chars else None for token_to_chars in candidates]
-            for candidates in end_candidates_tokens_to_chars
-        ]
-        probabilities = candidates.values.cpu()
-        return start_candidates_char_indices, end_candidates_char_indices, probabilities
+        return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, candidates_values
 
     def _nest_answers(
         self,
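To see why the probability > 0 filter is the actual fix: `torch.topk` always returns exactly k entries, so when `answers_per_seq` exceeds the number of unmasked positions, the surplus entries point at masked positions whose probability is exactly 0 (sigmoid of -inf). A minimal, standalone sketch of that behaviour (the tensor values below are made up for illustration, not taken from the reader):

import torch

# Pretend only two span candidates survive the masking; everything else was set to -inf
masked_logits = torch.tensor([3.0, 1.5, -torch.inf, -torch.inf, -torch.inf])
probabilities = torch.sigmoid(masked_logits)  # masked entries become exactly 0.0

answers_per_seq = 4  # more answers requested than there are valid candidates
candidates = torch.topk(probabilities, answers_per_seq)
print(candidates.values)  # tensor([0.9526, 0.8176, 0.0000, 0.0000]) -> two invalid entries

# Mirroring `valid = candidates_values[i] > 0` in the new code: drop zero-probability candidates
valid = candidates.values > 0
print(candidates.indices[valid].tolist())  # [0, 1] -> only genuinely unmasked positions remain

Before this change, those zero-probability candidates flowed on to `token_to_chars`, which is how the `None` char spans handled by the removed warning code (and the removed regression test below) could arise.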

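One more aside, on the unchanged lines of the same hunk: the (start, end) score matrix is flattened before topk, so the flat candidate indices have to be mapped back to a row (start token) and a column (end token) via `// seq_length` and `% seq_length`. A self-contained sketch of that index arithmetic, again with made-up values:

import torch

seq_length = 5
# Fake per-(start, end) span probabilities for one sequence, shape (seq_length, seq_length)
probabilities = torch.zeros(seq_length, seq_length)
probabilities[1, 3] = 0.9  # best span: starts at token 1, ends at token 3
probabilities[2, 4] = 0.7  # runner-up: starts at token 2, ends at token 4

flat_probabilities = probabilities.flatten(-2, -1)  # shape (seq_length * seq_length,)
candidates = torch.topk(flat_probabilities, 2)
start_candidates = candidates.indices // seq_length  # row index = start token
end_candidates = candidates.indices % seq_length  # column index = end token
print(start_candidates.tolist(), end_candidates.tolist())  # [1, 2] [3, 4]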

@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Fix issue with `ExtractiveReader` picking invalid answers when `answers_per_seq` > num of valid answers


@@ -261,78 +261,6 @@ def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
     mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")
 
 
-def test_missing_token_to_chars_values():
-    # See https://github.com/deepset-ai/haystack/issues/6098
-    def mock_tokenize(
-        texts: List[str],
-        text_pairs: List[str],
-        padding: bool,
-        truncation: bool,
-        max_length: int,
-        return_tensors: str,
-        return_overflowing_tokens: bool,
-        stride: int,
-    ):
-        assert padding
-        assert truncation
-        assert return_tensors == "pt"
-        assert return_overflowing_tokens
-
-        tokens = Mock()
-
-        num_splits = [ceil(len(text + pair) / max_length) for text, pair in zip(texts, text_pairs)]
-        tokens.overflow_to_sample_mapping = [i for i, num in enumerate(num_splits) for _ in range(num)]
-        num_samples = sum(num_splits)
-        tokens.encodings = [Mock() for _ in range(num_samples)]
-        sequence_ids = [0] * 16 + [1] * 16 + [None] * (max_length - 32)
-        for encoding in tokens.encodings:
-            encoding.sequence_ids = sequence_ids
-            encoding.token_to_chars = lambda i: None
-
-        tokens.input_ids = torch.zeros(num_samples, max_length, dtype=torch.int)
-        attention_mask = torch.zeros(num_samples, max_length, dtype=torch.int)
-        attention_mask[:32] = 1
-        tokens.attention_mask = attention_mask
-        return tokens
-
-    class MockModel(torch.nn.Module):
-        def to(self, device):
-            assert device == "cpu:0"
-            self.device_set = True
-            return self
-
-        def forward(self, input_ids, attention_mask, *args, **kwargs):
-            assert input_ids.device == torch.device("cpu")
-            assert attention_mask.device == torch.device("cpu")
-            assert self.device_set
-            start = torch.zeros(input_ids.shape[:2])
-            end = torch.zeros(input_ids.shape[:2])
-            start[:, 27] = 1
-            end[:, 31] = 1
-            end[:, 32] = 1
-            prediction = Mock()
-            prediction.start_logits = start
-            prediction.end_logits = end
-            return prediction
-
-    with patch("haystack.components.readers.extractive.AutoTokenizer.from_pretrained") as tokenizer, patch(
-        "haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained"
-    ) as model:
-        tokenizer.return_value = mock_tokenize
-        model.return_value = MockModel()
-        reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0")
-        reader.warm_up()
-        answers = reader.run(example_queries[0], example_documents[0], top_k=3)[
-            "answers"
-        ]  # [0] Uncomment and remove first two indices when batching support is reintroduced
-
-    for doc, answer in zip(example_documents[0], answers[:3]):
-        assert answer.start is None
-        assert answer.end is None
-        assert doc.content is not None
-        assert answer.data == doc.content
-
-
 @pytest.mark.integration
 def test_t5():
     reader = ExtractiveReader("TARUNBHATT/flan-t5-small-finetuned-squad")
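A closing note for readers unfamiliar with token_to_chars, which the removed test above mocks to always return `None`: on a Hugging Face fast tokenizer, each `tokenizers.Encoding` maps a token index back to its character span in the original text and returns `None` for tokens that belong to no input sequence, such as special tokens. A rough sketch, assuming the fast tokenizer of the model already named in these tests is available locally; the example texts are made up:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
tokens = tokenizer("Where does Mark live?", "Mark lives in Berlin.", return_tensors="pt")

encoding = tokens.encodings[0]  # tokenizers.Encoding for the single sample
print(encoding.token_to_chars(0))  # None: token 0 is the <s> special token
print(encoding.token_to_chars(1))  # (start, end) character offsets of the first real token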