feat: Extractive QA answer deduplication (#6459)

* Add answer deduplication

* Fix test

* Handle None case

* Release notes

* Handle cases where documents or answer spans could be None

* Adding checks for Nones and satisfying mypy

* Add option to turn off deduplication

* Adding unit tests

* Refactored tests to use fixtures

* Added overlap_threshold to run

* Update test

* Fixes related to the merge

* Remove casting, use direct variable names

* Move out if statement and add new test for it

* Update if statement to match comment

* Update how if statements work
Sebastian Husch Lee 2023-12-18 19:27:04 +01:00 committed by GitHub
parent c294b8ac8c
commit dcf37c5173
3 changed files with 355 additions and 20 deletions
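
For context, the new `overlap_threshold` option can be set when constructing the reader and overridden per `run()` call. A minimal usage sketch based on the signatures introduced in this diff (the query, documents, and model name are illustrative, not part of the change):

from haystack import Document
from haystack.components.readers import ExtractiveReader

# Drop answers whose spans overlap by more than 1% (the new default).
reader = ExtractiveReader("deepset/roberta-base-squad2", overlap_threshold=0.01)
reader.warm_up()

docs = [Document(content="I want to go to the river in Maine.")]

# Per-call override: pass overlap_threshold=None to keep all overlapping answers.
result = reader.run(query="Where do I want to go?", documents=docs, overlap_threshold=None)
print(result["answers"])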

View File

@@ -49,6 +49,7 @@ class ExtractiveReader:
         answers_per_seq: Optional[int] = None,
         no_answer: bool = True,
         calibration_factor: float = 0.1,
+        overlap_threshold: Optional[float] = 0.01,
         model_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
@@ -75,6 +76,12 @@
         :param no_answer: Whether to return an additional `no answer` with an empty text and a score representing the
             probability that the other top_k answers are incorrect.
         :param calibration_factor: Factor used for calibrating probabilities.
+        :param overlap_threshold: If set this will remove duplicate answers if they have an overlap larger than the
+            supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove
+            one of these answers since the second answer has a 100% (1.0) overlap with the first answer.
+            However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so
+            both of these answers could be kept if this variable is set to 0.24 or lower.
+            If None is provided then all answers are kept.
         :param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
             when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass,
             see the model's documentation.
@@ -93,6 +100,7 @@
         self.no_answer = no_answer
         self.calibration_factor = calibration_factor
         self.model_kwargs = model_kwargs or {}
+        self.overlap_threshold = overlap_threshold

     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
@@ -258,6 +266,7 @@
         query_ids: List[int],
         document_ids: List[int],
         no_answer: bool,
+        overlap_threshold: Optional[float],
     ) -> List[List[ExtractedAnswer]]:
         """
         Reconstructs the nested structure that existed before flattening. Also computes a no answer score.
@@ -290,7 +299,8 @@
                 answer.query = queries[query_id]
                 current_answers.append(answer)
                 i += 1
-            current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True)
+            current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True)
+            current_answers = self.deduplicate_by_overlap(current_answers, overlap_threshold=overlap_threshold)
             current_answers = current_answers[:top_k]
             if no_answer:
                 no_answer_score = math.prod(1 - answer.score for answer in current_answers)
@@ -298,13 +308,120 @@
                     data=None, query=queries[query_id], meta={}, document=None, score=no_answer_score
                 )
                 current_answers.append(answer_)
-            current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True)
+            current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True)
             if score_threshold is not None:
                 current_answers = [answer for answer in current_answers if answer.score >= score_threshold]
             nested_answers.append(current_answers)

         return nested_answers

+    def _calculate_overlap(self, answer1_start: int, answer1_end: int, answer2_start: int, answer2_end: int) -> int:
+        """
+        Calculates the amount of overlap (in number of characters) between two answer offsets.
+
+        Stack overflow post explaining how to calculate overlap between two ranges:
+        https://stackoverflow.com/questions/325933/determine-whether-two-date-ranges-overlap/325964#325964
+        """
+        # Check for overlap: (StartA <= EndB) and (StartB <= EndA)
+        if answer1_start <= answer2_end and answer2_start <= answer1_end:
+            return min(
+                answer1_end - answer1_start,
+                answer1_end - answer2_start,
+                answer2_end - answer1_start,
+                answer2_end - answer2_start,
+            )
+        return 0
+
+    def _should_keep(
+        self, candidate_answer: ExtractedAnswer, current_answers: List[ExtractedAnswer], overlap_threshold: float
+    ) -> bool:
+        """
+        Determine if the answer should be kept based on how much it overlaps with previous answers.
+
+        NOTE: We might want to avoid throwing away answers that only have a few character (or word) overlap:
+        - E.g. The answers "the river in" and "in Maine" from the context "I want to go to the river in Maine."
+          might both want to be kept.
+
+        :param candidate_answer: Candidate answer that will be checked if it should be kept.
+        :param current_answers: Current list of answers that will be kept.
+        :param overlap_threshold: If the overlap between two answers is greater than this threshold then return False.
+        """
+        keep = True
+
+        # If the candidate answer doesn't have a document keep it
+        if not candidate_answer.document:
+            return keep
+
+        for ans in current_answers:
+            # If an answer in current_answers doesn't have a document skip the comparison
+            if not ans.document:
+                continue
+
+            # If offset is missing then keep both
+            if ans.document_offset is None:
+                continue
+
+            # If offset is missing then keep both
+            if candidate_answer.document_offset is None:
+                continue
+
+            # If the answers come from different documents then keep both
+            if candidate_answer.document.id != ans.document.id:
+                continue
+
+            overlap_len = self._calculate_overlap(
+                answer1_start=ans.document_offset.start,
+                answer1_end=ans.document_offset.end,
+                answer2_start=candidate_answer.document_offset.start,
+                answer2_end=candidate_answer.document_offset.end,
+            )
+
+            # If overlap is 0 then keep
+            if overlap_len == 0:
+                continue
+
+            overlap_frac_answer1 = overlap_len / (ans.document_offset.end - ans.document_offset.start)
+            overlap_frac_answer2 = overlap_len / (
+                candidate_answer.document_offset.end - candidate_answer.document_offset.start
+            )
+
+            if overlap_frac_answer1 > overlap_threshold or overlap_frac_answer2 > overlap_threshold:
+                keep = False
+                break
+
+        return keep
+
+    def deduplicate_by_overlap(
+        self, answers: List[ExtractedAnswer], overlap_threshold: Optional[float]
+    ) -> List[ExtractedAnswer]:
+        """
+        This de-duplicates overlapping Extractive Answers from the same document based on how much the spans of the
+        answers overlap.
+
+        :param answers: List of answers to be deduplicated.
+        :param overlap_threshold: If set this will remove duplicate answers if they have an overlap larger than the
+            supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove
+            one of these answers since the second answer has a 100% (1.0) overlap with the first answer.
+            However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so
+            both of these answers could be kept if this variable is set to 0.24 or lower.
+            If None is provided then all answers are kept.
+        """
+        if overlap_threshold is None:
+            return answers
+
+        # Initialize with the first answer and its offsets_in_document
+        deduplicated_answers = [answers[0]]
+
+        # Loop over remaining answers to check for overlaps
+        for ans in answers[1:]:
+            keep = self._should_keep(
+                candidate_answer=ans, current_answers=deduplicated_answers, overlap_threshold=overlap_threshold
+            )
+            if keep:
+                deduplicated_answers.append(ans)
+
+        return deduplicated_answers
+
     @component.output_types(answers=List[ExtractedAnswer])
     def run(
         self,
@@ -317,6 +434,7 @@
         max_batch_size: Optional[int] = None,
         answers_per_seq: Optional[int] = None,
         no_answer: Optional[bool] = None,
+        overlap_threshold: Optional[float] = None,
     ):
         """
         Locates and extracts answers from the given Documents using the given query.
@@ -325,8 +443,6 @@
         :param documents: List of Documents in which you want to search for an answer to the query.
         :param top_k: The maximum number of answers to return.
             An additional answer is returned if no_answer is set to True (default).
-        :param score_threshold:
-        :return: List of ExtractedAnswers sorted by (desc.) answer score.
         :param score_threshold: Returns only answers with the score above this threshold.
         :param max_seq_length: Maximum number of tokens.
             If a sequence exceeds it, the sequence is split.
@@ -336,7 +452,17 @@
         :param max_batch_size: Maximum number of samples that are fed through the model at the same time.
         :param answers_per_seq: Number of answer candidates to consider per sequence.
             This is relevant when a Document was split into multiple sequences because of max_seq_length.
+            Default: 20
         :param no_answer: Whether to return no answer scores.
+            Default: True
+        :param overlap_threshold: If set this will remove duplicate answers if they have an overlap larger than the
+            supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove
+            one of these answers since the second answer has a 100% (1.0) overlap with the first answer.
+            However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so
+            both of these answers could be kept if this variable is set to 0.24 or lower.
+            If None is provided then all answers are kept.
+            Default: 0.01
+        :return: List of ExtractedAnswers sorted by (desc.) answer score.
         """
         queries = [query]  # Temporary solution until we have decided what batching should look like in v2
         nested_documents = [documents]
@@ -348,8 +474,9 @@
         max_seq_length = max_seq_length or self.max_seq_length
         stride = stride or self.stride
         max_batch_size = max_batch_size or self.max_batch_size
-        answers_per_seq = answers_per_seq or self.answers_per_seq or top_k or 20
+        answers_per_seq = answers_per_seq or self.answers_per_seq or 20
         no_answer = no_answer if no_answer is not None else self.no_answer
+        overlap_threshold = overlap_threshold or self.overlap_threshold

         flattened_queries, flattened_documents, query_ids = self._flatten_documents(queries, nested_documents)
         input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess(
@@ -385,17 +512,18 @@
         )

         answers = self._nest_answers(
-            start,
-            end,
-            probabilities,
-            flattened_documents,
-            queries,
-            answers_per_seq,
-            top_k,
-            score_threshold,
-            query_ids,
-            document_ids,
-            no_answer,
+            start=start,
+            end=end,
+            probabilities=probabilities,
+            flattened_documents=flattened_documents,
+            queries=queries,
+            answers_per_seq=answers_per_seq,
+            top_k=top_k,
+            score_threshold=score_threshold,
+            query_ids=query_ids,
+            document_ids=document_ids,
+            no_answer=no_answer,
+            overlap_threshold=overlap_threshold,
         )

         return {"answers": answers[0]}  # same temporary batching fix as above

View File

@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Introduces answer deduplication on the Document level based on an overlap threshold.

View File

@@ -7,7 +7,7 @@ import torch
 from transformers import pipeline

 from haystack.components.readers import ExtractiveReader
-from haystack import Document
+from haystack import Document, ExtractedAnswer


 @pytest.fixture
@@ -233,8 +233,19 @@ def test_nest_answers(mock_reader: ExtractiveReader):
     probabilities = torch.arange(5).unsqueeze(0) / 5 + torch.arange(6).unsqueeze(-1) / 25
     query_ids = [0] * 3 + [1] * 3
     document_ids = list(range(3)) * 2
-    nested_answers = mock_reader._nest_answers(
-        start, end, probabilities, example_documents[0], example_queries, 5, 3, None, query_ids, document_ids, True  # type: ignore
+    nested_answers = mock_reader._nest_answers(  # type: ignore
+        start=start,
+        end=end,
+        probabilities=probabilities,
+        flattened_documents=example_documents[0],
+        queries=example_queries,
+        answers_per_seq=5,
+        top_k=3,
+        score_threshold=None,
+        query_ids=query_ids,
+        document_ids=document_ids,
+        no_answer=True,
+        overlap_threshold=None,
     )
     expected_no_answers = [0.2 * 0.16 * 0.12, 0]
     for query, answers, expected_no_answer, probabilities in zip(
@@ -261,6 +272,196 @@ def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
     mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")


+class TestDeduplication:
+    @pytest.fixture
+    def doc1(self):
+        return Document(content="I want to go to the river in Maine.")
+
+    @pytest.fixture
+    def doc2(self):
+        return Document(content="I want to go skiing in Colorado.")
+
+    @pytest.fixture
+    def candidate_answer(self, doc1):
+        answer1 = "the river"
+        return ExtractedAnswer(
+            query="test",
+            data=answer1,
+            document=doc1,
+            document_offset=ExtractedAnswer.Span(doc1.content.find(answer1), doc1.content.find(answer1) + len(answer1)),
+            score=0.1,
+            meta={},
+        )
+
+    def test_calculate_overlap(self, mock_reader: ExtractiveReader, doc1: Document):
+        answer1 = "the river"
+        answer2 = "river in Maine"
+        overlap_in_characters = mock_reader._calculate_overlap(
+            answer1_start=doc1.content.find(answer1),
+            answer1_end=doc1.content.find(answer1) + len(answer1),
+            answer2_start=doc1.content.find(answer2),
+            answer2_end=doc1.content.find(answer2) + len(answer2),
+        )
+        assert overlap_in_characters == 5
+
+    def test_should_keep_false(
+        self, mock_reader: ExtractiveReader, doc1: Document, doc2: Document, candidate_answer: ExtractedAnswer
+    ):
+        answer2 = "river in Maine"
+        answer3 = "skiing in Colorado"
+        keep = mock_reader._should_keep(
+            candidate_answer=candidate_answer,
+            current_answers=[
+                ExtractedAnswer(
+                    query="test",
+                    data=answer2,
+                    document=doc1,
+                    document_offset=ExtractedAnswer.Span(
+                        doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)
+                    ),
+                    score=0.1,
+                    meta={},
+                ),
+                ExtractedAnswer(
+                    query="test",
+                    data=answer3,
+                    document=doc2,
+                    document_offset=ExtractedAnswer.Span(
+                        doc2.content.find(answer3), doc2.content.find(answer3) + len(answer3)
+                    ),
+                    score=0.1,
+                    meta={},
+                ),
+            ],
+            overlap_threshold=0.01,
+        )
+        assert keep is False
+
+    def test_should_keep_true(
+        self, mock_reader: ExtractiveReader, doc1: Document, doc2: Document, candidate_answer: ExtractedAnswer
+    ):
+        answer2 = "Maine"
+        answer3 = "skiing in Colorado"
+        keep = mock_reader._should_keep(
+            candidate_answer=candidate_answer,
+            current_answers=[
+                ExtractedAnswer(
+                    query="test",
+                    data=answer2,
+                    document=doc1,
+                    document_offset=ExtractedAnswer.Span(
+                        doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)
+                    ),
+                    score=0.1,
+                    meta={},
+                ),
+                ExtractedAnswer(
+                    query="test",
+                    data=answer3,
+                    document=doc2,
+                    document_offset=ExtractedAnswer.Span(
+                        doc2.content.find(answer3), doc2.content.find(answer3) + len(answer3)
+                    ),
+                    score=0.1,
+                    meta={},
+                ),
+            ],
+            overlap_threshold=0.01,
+        )
+        assert keep is True
+
+    def test_should_keep_missing_document_current_answer(
+        self, mock_reader: ExtractiveReader, doc1: Document, candidate_answer: ExtractedAnswer
+    ):
+        answer2 = "river in Maine"
+        keep = mock_reader._should_keep(
+            candidate_answer=candidate_answer,
+            current_answers=[
+                ExtractedAnswer(
+                    query="test",
+                    data=answer2,
+                    document=None,
+                    document_offset=ExtractedAnswer.Span(
+                        doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)
+                    ),
+                    score=0.1,
+                    meta={},
+                )
+            ],
+            overlap_threshold=0.01,
+        )
+        assert keep is True
+
+    def test_should_keep_missing_document_candidate_answer(
+        self, mock_reader: ExtractiveReader, doc1: Document, candidate_answer: ExtractedAnswer
+    ):
+        answer2 = "river in Maine"
+        keep = mock_reader._should_keep(
+            candidate_answer=ExtractedAnswer(
+                query="test",
+                data=answer2,
+                document=None,
+                document_offset=ExtractedAnswer.Span(
+                    doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)
+                ),
+                score=0.1,
+                meta={},
+            ),
+            current_answers=[
+                ExtractedAnswer(
+                    query="test",
+                    data=answer2,
+                    document=doc1,
+                    document_offset=ExtractedAnswer.Span(
+                        doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)
+                    ),
+                    score=0.1,
+                    meta={},
+                )
+            ],
+            overlap_threshold=0.01,
+        )
+        assert keep is True
+
+    def test_should_keep_missing_span(
+        self, mock_reader: ExtractiveReader, doc1: Document, candidate_answer: ExtractedAnswer
+    ):
+        answer2 = "river in Maine"
+        keep = mock_reader._should_keep(
+            candidate_answer=candidate_answer,
+            current_answers=[
+                ExtractedAnswer(query="test", data=answer2, document=doc1, document_offset=None, score=0.1, meta={})
+            ],
+            overlap_threshold=0.01,
+        )
+        assert keep is True
+
+    def test_deduplicate_by_overlap_none_overlap(
+        self, mock_reader: ExtractiveReader, candidate_answer: ExtractedAnswer
+    ):
+        result = mock_reader.deduplicate_by_overlap(
+            answers=[candidate_answer, candidate_answer], overlap_threshold=None
+        )
+        assert len(result) == 2
+
+    def test_deduplicate_by_overlap(
+        self, mock_reader: ExtractiveReader, candidate_answer: ExtractedAnswer, doc1: Document
+    ):
+        answer2 = "Maine"
+        extracted_answer2 = ExtractedAnswer(
+            query="test",
+            data=answer2,
+            document=doc1,
+            document_offset=ExtractedAnswer.Span(doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)),
+            score=0.1,
+            meta={},
+        )
+        result = mock_reader.deduplicate_by_overlap(
+            answers=[candidate_answer, candidate_answer, extracted_answer2], overlap_threshold=0.01
+        )
+        assert len(result) == 2
+
+
 @pytest.mark.integration
 def test_t5():
     reader = ExtractiveReader("TARUNBHATT/flan-t5-small-finetuned-squad")
@@ -274,6 +475,7 @@ def test_t5():
     assert answers[1].score == pytest.approx(0.7703777551651001)
     assert answers[2].data is None
     assert answers[2].score == pytest.approx(0.051331606147570596)
+    assert len(answers) == 3
     # Uncomment assertions below when batching is reintroduced
     # assert answers[0][2].score == pytest.approx(0.051331606147570596)
     # assert answers[1][0].data == "Jerry"
@@ -297,6 +499,7 @@ def test_roberta():
     assert answers[1].score == pytest.approx(0.857952892780304)
     assert answers[2].data is None
     assert answers[2].score == pytest.approx(0.019673851661650588)
+    assert len(answers) == 3
     # uncomment assertions below when there is batching in v2
     # assert answers[0][0].data == "Olaf Scholz"
     # assert answers[0][0].score == pytest.approx(0.8614975214004517)
@@ -314,7 +517,7 @@
 @pytest.mark.integration
 def test_matches_hf_pipeline():
-    reader = ExtractiveReader("deepset/tinyroberta-squad2", device="cpu")
+    reader = ExtractiveReader("deepset/tinyroberta-squad2", device="cpu", overlap_threshold=None)
     reader.warm_up()
     answers = reader.run(example_queries[0], [[example_documents[0][0]]][0], top_k=20, no_answer=False)[
         "answers"