Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-26 00:24:14 +00:00)
fix: Prevent invalid answer from being selected in ExtractiveReader (#6460)
* Fix invalid answer being selected issue on ExtractiveReader
* Rename variables to not shadow arguments
This commit is contained in:
parent 05a30c24aa
commit c5342d1110
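In short, the reader masks invalid span positions with -inf logits; `torch.topk` can still return those positions when `answers_per_seq` exceeds the number of valid spans, and the fix keeps only candidates whose probability is greater than 0. A minimal sketch of that mechanism in plain PyTorch (illustrative only, not the Haystack implementation):

```python
# Hedged illustration of the bug this commit fixes (plain PyTorch, not the
# Haystack code itself): masked span scores are -inf, so after sigmoid their
# probability is exactly 0, yet torch.topk still returns them whenever
# answers_per_seq exceeds the number of valid spans. Keeping only
# probability > 0 candidates drops these invalid answers.
import torch

span_logits = torch.tensor([2.0, 0.5, float("-inf"), float("-inf")])  # 2 valid spans, 2 masked
probabilities = torch.sigmoid(span_logits)
candidates = torch.topk(probabilities, k=3)  # k (answers_per_seq) > number of valid spans
valid = candidates.values > 0                # masked entries have probability exactly 0
print(candidates.indices[valid])             # only the two valid spans remain
```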
@@ -206,62 +206,52 @@ class ExtractiveReader:
         implementations, it doesn't normalize the scores to make them easier to compare across different
         splits. Returns the top k answer spans.
         """
-        mask = sequence_ids == 1
-        mask = torch.logical_and(mask, attention_mask == 1)
-        start = torch.where(mask, start, -torch.inf)
-        end = torch.where(mask, end, -torch.inf)
+        mask = sequence_ids == 1  # Only keep tokens from the context (should ignore special tokens)
+        mask = torch.logical_and(mask, attention_mask == 1)  # Definitely remove special tokens
+        start = torch.where(mask, start, -torch.inf)  # Apply the mask on the start logits
+        end = torch.where(mask, end, -torch.inf)  # Apply the mask on the end logits
         start = start.unsqueeze(-1)
         end = end.unsqueeze(-2)

         logits = start + end  # shape: (batch_size, seq_length (start), seq_length (end))

+        # The mask here onwards is the same for all instances in the batch
+        # As such we do away with the batch dimension
         mask = torch.ones(logits.shape[-2:], dtype=torch.bool, device=self.device)
         mask = torch.triu(mask)  # End shouldn't be before start
         masked_logits = torch.where(mask, logits, -torch.inf)
         probabilities = torch.sigmoid(masked_logits * self.calibration_factor)

         flat_probabilities = probabilities.flatten(-2, -1)  # necessary for topk

+        # topk can return invalid candidates as well if answers_per_seq > num_valid_candidates
+        # We only keep probability > 0 candidates later on
         candidates = torch.topk(flat_probabilities, answers_per_seq)
         seq_length = logits.shape[-1]
         start_candidates = candidates.indices // seq_length  # Recover indices from flattening
         end_candidates = candidates.indices % seq_length
+        candidates_values = candidates.values.cpu()
         start_candidates = start_candidates.cpu()
         end_candidates = end_candidates.cpu()

-        start_candidates_tokens_to_chars = [
-            [encoding.token_to_chars(start) for start in candidates]
-            for candidates, encoding in zip(start_candidates, encodings)
-        ]
-        if missing_start_tokens := [
-            (batch, index)
-            for batch, token_to_chars in enumerate(start_candidates_tokens_to_chars)
-            for index, pair in enumerate(token_to_chars)
-            if pair is None
-        ]:
-            logger.warning("Some start tokens could not be found in the context: %s", missing_start_tokens)
-        start_candidates_char_indices = [
-            [token_to_chars[0] if token_to_chars else None for token_to_chars in candidates]
-            for candidates in start_candidates_tokens_to_chars
-        ]
-
-        end_candidates_tokens_to_chars = [
-            [encoding.token_to_chars(end) for end in candidates]
-            for candidates, encoding in zip(end_candidates, encodings)
-        ]
-        if missing_end_tokens := [
-            (batch, index)
-            for batch, token_to_chars in enumerate(end_candidates_tokens_to_chars)
-            for index, pair in enumerate(token_to_chars)
-            if pair is None
-        ]:
-            logger.warning("Some end tokens could not be found in the context: %s", missing_end_tokens)
-        end_candidates_char_indices = [
-            [token_to_chars[1] if token_to_chars else None for token_to_chars in candidates]
-            for candidates in end_candidates_tokens_to_chars
-        ]
-
-        probabilities = candidates.values.cpu()
-
-        return start_candidates_char_indices, end_candidates_char_indices, probabilities
+        start_candidates_tokens_to_chars = []
+        end_candidates_tokens_to_chars = []
+        for i, (s_candidates, e_candidates, encoding) in enumerate(zip(start_candidates, end_candidates, encodings)):
+            # Those with probabilities > 0 are valid
+            valid = candidates_values[i] > 0
+            s_char_spans = []
+            e_char_spans = []
+            for start_token, end_token in zip(s_candidates[valid], e_candidates[valid]):
+                # token_to_chars returns `None` for special tokens
+                # But we shouldn't have special tokens in the answers at this point
+                # The whole span is given by the start of the start_token (index 0)
+                # and the end of the end token (index 1)
+                s_char_spans.append(encoding.token_to_chars(start_token)[0])
+                e_char_spans.append(encoding.token_to_chars(end_token)[1])
+            start_candidates_tokens_to_chars.append(s_char_spans)
+            end_candidates_tokens_to_chars.append(e_char_spans)
+
+        return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, candidates_values

     def _nest_answers(
         self,
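For readers unfamiliar with the span-scoring scheme the hunk above operates on, here is a small self-contained sketch of the same mechanics in plain PyTorch: build the (start, end) logit matrix, mask spans whose end precedes their start, take top-k over the flattened matrix, and recover the (start, end) indices. Shapes and values are made up for illustration; this is not the Haystack code.

```python
# Minimal sketch (batch_size=1, toy values) of the span scoring and the
# flattened top-k index recovery used in the hunk above.
import torch

seq_length = 5
start_logits = torch.tensor([[0.2, 2.0, 0.1, 1.5, -1.0]])  # shape: (1, seq_length)
end_logits = torch.tensor([[0.0, 0.5, 2.5, 0.3, -2.0]])

# (batch, start, end) matrix of span scores
logits = start_logits.unsqueeze(-1) + end_logits.unsqueeze(-2)

# Upper-triangular mask: the end token may not come before the start token
mask = torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool))
masked_logits = torch.where(mask, logits, -torch.inf)
probabilities = torch.sigmoid(masked_logits)

# top-k over the flattened (start, end) matrix, then undo the flattening
candidates = torch.topk(probabilities.flatten(-2, -1), k=3)
start_candidates = candidates.indices // seq_length
end_candidates = candidates.indices % seq_length

# Masked (invalid) spans have probability exactly 0; keeping only > 0 entries
# mirrors the filtering introduced in this commit.
valid = candidates.values > 0
print(start_candidates[valid], end_candidates[valid], candidates.values[valid])
```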
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Fix issue with `ExtractiveReader` picking invalid answers when `answers_per_seq` > num of valid answers
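The release note above refers to the reader's public API. As a reference point, a minimal usage sketch follows; the reader import path, `warm_up()`, the `run(query=..., documents=..., top_k=...)` call, and the `data`/`start`/`end` answer fields mirror the tests in this diff, while the `Document` import path, the example document, and the query are assumptions for illustration.

```python
# Minimal usage sketch; treat the Document import path and the example data
# as assumptions, not canonical documentation.
from haystack.components.readers.extractive import ExtractiveReader
from haystack.dataclasses import Document

reader = ExtractiveReader("deepset/roberta-base-squad2")
reader.warm_up()

documents = [Document(content="Paris is the capital of France.")]
result = reader.run(query="What is the capital of France?", documents=documents, top_k=3)
for answer in result["answers"]:
    # After this fix, zero-probability (invalid) candidates no longer show up here
    print(answer.data, answer.start, answer.end)
```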
@@ -261,78 +261,6 @@ def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
     mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")


-def test_missing_token_to_chars_values():
-    # See https://github.com/deepset-ai/haystack/issues/6098
-
-    def mock_tokenize(
-        texts: List[str],
-        text_pairs: List[str],
-        padding: bool,
-        truncation: bool,
-        max_length: int,
-        return_tensors: str,
-        return_overflowing_tokens: bool,
-        stride: int,
-    ):
-        assert padding
-        assert truncation
-        assert return_tensors == "pt"
-        assert return_overflowing_tokens
-
-        tokens = Mock()
-
-        num_splits = [ceil(len(text + pair) / max_length) for text, pair in zip(texts, text_pairs)]
-        tokens.overflow_to_sample_mapping = [i for i, num in enumerate(num_splits) for _ in range(num)]
-        num_samples = sum(num_splits)
-        tokens.encodings = [Mock() for _ in range(num_samples)]
-        sequence_ids = [0] * 16 + [1] * 16 + [None] * (max_length - 32)
-        for encoding in tokens.encodings:
-            encoding.sequence_ids = sequence_ids
-            encoding.token_to_chars = lambda i: None
-        tokens.input_ids = torch.zeros(num_samples, max_length, dtype=torch.int)
-        attention_mask = torch.zeros(num_samples, max_length, dtype=torch.int)
-        attention_mask[:32] = 1
-        tokens.attention_mask = attention_mask
-        return tokens
-
-    class MockModel(torch.nn.Module):
-        def to(self, device):
-            assert device == "cpu:0"
-            self.device_set = True
-            return self
-
-        def forward(self, input_ids, attention_mask, *args, **kwargs):
-            assert input_ids.device == torch.device("cpu")
-            assert attention_mask.device == torch.device("cpu")
-            assert self.device_set
-            start = torch.zeros(input_ids.shape[:2])
-            end = torch.zeros(input_ids.shape[:2])
-            start[:, 27] = 1
-            end[:, 31] = 1
-            end[:, 32] = 1
-            prediction = Mock()
-            prediction.start_logits = start
-            prediction.end_logits = end
-            return prediction
-
-    with patch("haystack.components.readers.extractive.AutoTokenizer.from_pretrained") as tokenizer, patch(
-        "haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained"
-    ) as model:
-        tokenizer.return_value = mock_tokenize
-        model.return_value = MockModel()
-        reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0")
-        reader.warm_up()
-
-        answers = reader.run(example_queries[0], example_documents[0], top_k=3)[
-            "answers"
-        ]  # [0] Uncomment and remove first two indices when batching support is reintroduced
-        for doc, answer in zip(example_documents[0], answers[:3]):
-            assert answer.start is None
-            assert answer.end is None
-            assert doc.content is not None
-            assert answer.data == doc.content
-
-
 @pytest.mark.integration
 def test_t5():
     reader = ExtractiveReader("TARUNBHATT/flan-t5-small-finetuned-squad")
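The removed test above exercised the old fallback path in which `token_to_chars` returned `None` and the reader fell back to whole-document answers without offsets; with the fix, zero-probability candidates are filtered out before that conversion ever happens. Purely as an illustrative sketch (not part of this commit), the property the fix aims to guarantee could be checked like this:

```python
# Illustrative sketch only, not part of this commit: every extractive answer
# that carries char offsets should slice its document's content to exactly the
# answer text; "no answer" entries carry no offsets and are skipped.
def check_answer_spans(answers, doc):
    for answer in answers:
        if answer.start is None or answer.end is None:
            continue
        assert doc.content[answer.start : answer.end] == answer.data
```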