Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-25 08:04:49 +00:00)
fix: Prevent invalid answer from being selected in ExtractiveReader (#6460)
* Fix invalid answer being selected issue on ExtractiveReader
* Rename variables to not shadow arguments

This commit is contained in:
parent 05a30c24aa
commit c5342d1110
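For context before the diff, a minimal standalone sketch (toy shapes and values, not taken from the repository) of the mechanism behind the bug: torch.topk over the flattened span probabilities still returns answers_per_seq candidates even when fewer spans have non-zero probability, so the extra candidates point at masked, invalid spans. The change below keeps only candidates with probability > 0 before mapping tokens to characters.

    import torch

    # Toy example: 1 sequence, 4 tokens; position (i, j) is a span from token i to token j.
    # Invalid positions are masked to -inf, so their sigmoid probability is exactly 0.
    logits = torch.full((1, 4, 4), -torch.inf)
    logits[0, 1, 2] = 5.0  # one strong candidate span
    logits[0, 2, 2] = 1.0  # one weaker candidate span
    probabilities = torch.sigmoid(logits)

    flat_probabilities = probabilities.flatten(-2, -1)  # shape: (1, 16)
    answers_per_seq = 4  # more answers requested than there are valid spans
    candidates = torch.topk(flat_probabilities, answers_per_seq)
    print(candidates.values)  # tensor([[0.9933, 0.7311, 0.0000, 0.0000]]) -> two invalid candidates

    # The fix: keep only candidates whose probability is > 0 before mapping tokens to characters
    valid = candidates.values[0] > 0
    seq_length = 4
    start_tokens = (candidates.indices[0] // seq_length)[valid]
    end_tokens = (candidates.indices[0] % seq_length)[valid]
    print(start_tokens.tolist(), end_tokens.tolist())  # [1, 2] [2, 2]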
@@ -206,62 +206,52 @@ class ExtractiveReader:
         implementations, it doesn't normalize the scores to make them easier to compare across different
         splits. Returns the top k answer spans.
         """
-        mask = sequence_ids == 1
-        mask = torch.logical_and(mask, attention_mask == 1)
-        start = torch.where(mask, start, -torch.inf)
-        end = torch.where(mask, end, -torch.inf)
+        mask = sequence_ids == 1  # Only keep tokens from the context (should ignore special tokens)
+        mask = torch.logical_and(mask, attention_mask == 1)  # Definitely remove special tokens
+        start = torch.where(mask, start, -torch.inf)  # Apply the mask on the start logits
+        end = torch.where(mask, end, -torch.inf)  # Apply the mask on the end logits
         start = start.unsqueeze(-1)
         end = end.unsqueeze(-2)
 
         logits = start + end  # shape: (batch_size, seq_length (start), seq_length (end))
+
+        # The mask here onwards is the same for all instances in the batch
+        # As such we do away with the batch dimension
         mask = torch.ones(logits.shape[-2:], dtype=torch.bool, device=self.device)
         mask = torch.triu(mask)  # End shouldn't be before start
         masked_logits = torch.where(mask, logits, -torch.inf)
         probabilities = torch.sigmoid(masked_logits * self.calibration_factor)
 
         flat_probabilities = probabilities.flatten(-2, -1)  # necessary for topk
+
+        # topk can return invalid candidates as well if answers_per_seq > num_valid_candidates
+        # We only keep probability > 0 candidates later on
         candidates = torch.topk(flat_probabilities, answers_per_seq)
         seq_length = logits.shape[-1]
         start_candidates = candidates.indices // seq_length  # Recover indices from flattening
         end_candidates = candidates.indices % seq_length
+        candidates_values = candidates.values.cpu()
         start_candidates = start_candidates.cpu()
         end_candidates = end_candidates.cpu()
 
-        start_candidates_tokens_to_chars = [
-            [encoding.token_to_chars(start) for start in candidates]
-            for candidates, encoding in zip(start_candidates, encodings)
-        ]
-        if missing_start_tokens := [
-            (batch, index)
-            for batch, token_to_chars in enumerate(start_candidates_tokens_to_chars)
-            for index, pair in enumerate(token_to_chars)
-            if pair is None
-        ]:
-            logger.warning("Some start tokens could not be found in the context: %s", missing_start_tokens)
-        start_candidates_char_indices = [
-            [token_to_chars[0] if token_to_chars else None for token_to_chars in candidates]
-            for candidates in start_candidates_tokens_to_chars
-        ]
-
-        end_candidates_tokens_to_chars = [
-            [encoding.token_to_chars(end) for end in candidates]
-            for candidates, encoding in zip(end_candidates, encodings)
-        ]
-        if missing_end_tokens := [
-            (batch, index)
-            for batch, token_to_chars in enumerate(end_candidates_tokens_to_chars)
-            for index, pair in enumerate(token_to_chars)
-            if pair is None
-        ]:
-            logger.warning("Some end tokens could not be found in the context: %s", missing_end_tokens)
-        end_candidates_char_indices = [
-            [token_to_chars[1] if token_to_chars else None for token_to_chars in candidates]
-            for candidates in end_candidates_tokens_to_chars
-        ]
-
-        probabilities = candidates.values.cpu()
-
-        return start_candidates_char_indices, end_candidates_char_indices, probabilities
+        start_candidates_tokens_to_chars = []
+        end_candidates_tokens_to_chars = []
+        for i, (s_candidates, e_candidates, encoding) in enumerate(zip(start_candidates, end_candidates, encodings)):
+            # Those with probabilities > 0 are valid
+            valid = candidates_values[i] > 0
+            s_char_spans = []
+            e_char_spans = []
+            for start_token, end_token in zip(s_candidates[valid], e_candidates[valid]):
+                # token_to_chars returns `None` for special tokens
+                # But we shouldn't have special tokens in the answers at this point
+                # The whole span is given by the start of the start_token (index 0)
+                # and the end of the end token (index 1)
+                s_char_spans.append(encoding.token_to_chars(start_token)[0])
+                e_char_spans.append(encoding.token_to_chars(end_token)[1])
+            start_candidates_tokens_to_chars.append(s_char_spans)
+            end_candidates_tokens_to_chars.append(e_char_spans)
+
+        return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, candidates_values
 
     def _nest_answers(
         self,
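As a side note on the unchanged masking step in the hunk above, here is a small sketch (toy seq_length, not from the commit) of how torch.triu builds the span-validity mask so that an answer's end token cannot precede its start token:

    import torch

    # Each position (i, j) of the logits matrix represents a span starting at token i
    # and ending at token j, so only the upper triangle (j >= i) contains well-formed spans.
    seq_length = 4
    mask = torch.ones((seq_length, seq_length), dtype=torch.bool)
    mask = torch.triu(mask)  # End shouldn't be before start
    print(mask)
    # tensor([[ True,  True,  True,  True],
    #         [False,  True,  True,  True],
    #         [False, False,  True,  True],
    #         [False, False, False,  True]])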
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Fix issue with `ExtractiveReader` picking invalid answers when `answers_per_seq` > num of valid answers
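A usage-level illustration of the scenario the release note describes; this is a hedged sketch only (model name and texts are placeholders, the call pattern mirrors the reader usage in this repository's tests, and exact answer fields may differ across Haystack versions):

    from haystack.dataclasses import Document
    from haystack.components.readers.extractive import ExtractiveReader

    reader = ExtractiveReader(model_name_or_path="deepset/roberta-base-squad2")
    reader.warm_up()

    docs = [Document(content="Python is a high-level programming language.")]
    # Ask for more answers than the short context can reasonably provide;
    # before this fix, zero-probability (invalid) spans could slip into the results.
    result = reader.run(query="What is Python?", documents=docs, top_k=10)
    for answer in result["answers"]:
        print(repr(answer.data))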
@@ -261,78 +261,6 @@ def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
     mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")
 
 
-def test_missing_token_to_chars_values():
-    # See https://github.com/deepset-ai/haystack/issues/6098
-
-    def mock_tokenize(
-        texts: List[str],
-        text_pairs: List[str],
-        padding: bool,
-        truncation: bool,
-        max_length: int,
-        return_tensors: str,
-        return_overflowing_tokens: bool,
-        stride: int,
-    ):
-        assert padding
-        assert truncation
-        assert return_tensors == "pt"
-        assert return_overflowing_tokens
-
-        tokens = Mock()
-
-        num_splits = [ceil(len(text + pair) / max_length) for text, pair in zip(texts, text_pairs)]
-        tokens.overflow_to_sample_mapping = [i for i, num in enumerate(num_splits) for _ in range(num)]
-        num_samples = sum(num_splits)
-        tokens.encodings = [Mock() for _ in range(num_samples)]
-        sequence_ids = [0] * 16 + [1] * 16 + [None] * (max_length - 32)
-        for encoding in tokens.encodings:
-            encoding.sequence_ids = sequence_ids
-            encoding.token_to_chars = lambda i: None
-        tokens.input_ids = torch.zeros(num_samples, max_length, dtype=torch.int)
-        attention_mask = torch.zeros(num_samples, max_length, dtype=torch.int)
-        attention_mask[:32] = 1
-        tokens.attention_mask = attention_mask
-        return tokens
-
-    class MockModel(torch.nn.Module):
-        def to(self, device):
-            assert device == "cpu:0"
-            self.device_set = True
-            return self
-
-        def forward(self, input_ids, attention_mask, *args, **kwargs):
-            assert input_ids.device == torch.device("cpu")
-            assert attention_mask.device == torch.device("cpu")
-            assert self.device_set
-            start = torch.zeros(input_ids.shape[:2])
-            end = torch.zeros(input_ids.shape[:2])
-            start[:, 27] = 1
-            end[:, 31] = 1
-            end[:, 32] = 1
-            prediction = Mock()
-            prediction.start_logits = start
-            prediction.end_logits = end
-            return prediction
-
-    with patch("haystack.components.readers.extractive.AutoTokenizer.from_pretrained") as tokenizer, patch(
-        "haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained"
-    ) as model:
-        tokenizer.return_value = mock_tokenize
-        model.return_value = MockModel()
-        reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0")
-        reader.warm_up()
-
-        answers = reader.run(example_queries[0], example_documents[0], top_k=3)[
-            "answers"
-        ]  # [0] Uncomment and remove first two indices when batching support is reintroduced
-        for doc, answer in zip(example_documents[0], answers[:3]):
-            assert answer.start is None
-            assert answer.end is None
-            assert doc.content is not None
-            assert answer.data == doc.content
-
-
 @pytest.mark.integration
 def test_t5():
     reader = ExtractiveReader("TARUNBHATT/flan-t5-small-finetuned-squad")