From dcf37c517359b2235b5524bde68ca662bdcf9c8f Mon Sep 17 00:00:00 2001
From: Sebastian Husch Lee
Date: Mon, 18 Dec 2023 19:27:04 +0100
Subject: [PATCH] feat: Extractive QA answer deduplication (#6459)

* Add answer deduplication
* Fix test
* Handle None case
* Release notes
* Handle cases where documents or answer spans could be None
* Adding checks for Nones and satisfying mypy
* Add option to turn off deduplication
* Adding unit tests
* Refactored tests to use fixtures
* Added overlap_threshold to run
* Update test
* Fixes related to the merge
* Remove casting, use direct variable names
* Move out if statement and add new test for it
* Update if statement to match comment
* Update how if statements work
---
 haystack/components/readers/extractive.py     | 160 +++++++++++--
 ...tive-qa-answer-dedup-7ca3b94b79b38854.yaml |   4 +
 test/components/readers/test_extractive.py    | 211 +++++++++++++++++-
 3 files changed, 355 insertions(+), 20 deletions(-)
 create mode 100644 releasenotes/notes/extractive-qa-answer-dedup-7ca3b94b79b38854.yaml

diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py
index 5eaa42839..dea016f77 100644
--- a/haystack/components/readers/extractive.py
+++ b/haystack/components/readers/extractive.py
@@ -49,6 +49,7 @@ class ExtractiveReader:
         answers_per_seq: Optional[int] = None,
         no_answer: bool = True,
         calibration_factor: float = 0.1,
+        overlap_threshold: Optional[float] = 0.01,
         model_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
@@ -75,6 +76,12 @@ class ExtractiveReader:
         :param no_answer: Whether to return an additional `no answer` with an empty text and a score
             representing the probability that the other top_k answers are incorrect.
         :param calibration_factor: Factor used for calibrating probabilities.
+        :param overlap_threshold: If set, removes duplicate answers whose overlap with another answer exceeds
+            the supplied threshold. For example, for the answers "in the river in Maine" and "the river", one
+            of the two is removed, since the second answer has a 100% (1.0) overlap with the first. However,
+            for the answers "the river in" and "in Maine" the maximum overlap fraction is only 25%, so both
+            answers are kept as long as this parameter is set to 0.25 or higher.
+            If None is provided, all answers are kept.
         :param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
             when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass, see
             the model's documentation.
@@ -93,6 +100,7 @@ class ExtractiveReader:
         self.no_answer = no_answer
         self.calibration_factor = calibration_factor
         self.model_kwargs = model_kwargs or {}
+        self.overlap_threshold = overlap_threshold

     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
@@ -258,6 +266,7 @@ class ExtractiveReader:
         query_ids: List[int],
         document_ids: List[int],
         no_answer: bool,
+        overlap_threshold: Optional[float],
     ) -> List[List[ExtractedAnswer]]:
         """
         Reconstructs the nested structure that existed before flattening. Also computes a no answer score.
@@ -290,7 +299,8 @@ class ExtractiveReader:
                 answer.query = queries[query_id]
                 current_answers.append(answer)
                 i += 1
-            current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True)
+            current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True)
+            current_answers = self.deduplicate_by_overlap(current_answers, overlap_threshold=overlap_threshold)
             current_answers = current_answers[:top_k]
             if no_answer:
                 no_answer_score = math.prod(1 - answer.score for answer in current_answers)
@@ -298,13 +308,120 @@ class ExtractiveReader:
                     data=None, query=queries[query_id], meta={}, document=None, score=no_answer_score
                 )
                 current_answers.append(answer_)
-                current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True)
+                current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True)
             if score_threshold is not None:
                 current_answers = [answer for answer in current_answers if answer.score >= score_threshold]
             nested_answers.append(current_answers)

         return nested_answers

+    def _calculate_overlap(self, answer1_start: int, answer1_end: int, answer2_start: int, answer2_end: int) -> int:
+        """
+        Calculates the amount of overlap (in number of characters) between two answer offsets.
+
+        Stack Overflow post explaining how to calculate the overlap between two ranges:
+        https://stackoverflow.com/questions/325933/determine-whether-two-date-ranges-overlap/325964#325964
+        """
+        # Check for overlap: (StartA <= EndB) and (StartB <= EndA)
+        if answer1_start <= answer2_end and answer2_start <= answer1_end:
+            return min(
+                answer1_end - answer1_start,
+                answer1_end - answer2_start,
+                answer2_end - answer1_start,
+                answer2_end - answer2_start,
+            )
+        return 0
+
+    def _should_keep(
+        self, candidate_answer: ExtractedAnswer, current_answers: List[ExtractedAnswer], overlap_threshold: float
+    ) -> bool:
+        """
+        Determine whether the candidate answer should be kept based on how much it overlaps with previous answers.
+
+        NOTE: We might want to avoid throwing away answers that have only a few characters (or words) of overlap:
+        - E.g. the answers "the river in" and "in Maine" from the context "I want to go to the river in Maine."
+          might both be worth keeping.
+
+        :param candidate_answer: Candidate answer checked against the answers kept so far.
+        :param current_answers: Current list of answers that will be kept.
+        :param overlap_threshold: If the overlap between two answers exceeds this threshold, return False.
+ """ + keep = True + + # If the candidate answer doesn't have a document keep it + if not candidate_answer.document: + return keep + + for ans in current_answers: + # If an answer in current_answers doesn't have a document skip the comparison + if not ans.document: + continue + + # If offset is missing then keep both + if ans.document_offset is None: + continue + + # If offset is missing then keep both + if candidate_answer.document_offset is None: + continue + + # If the answers come from different documents then keep both + if candidate_answer.document.id != ans.document.id: + continue + + overlap_len = self._calculate_overlap( + answer1_start=ans.document_offset.start, + answer1_end=ans.document_offset.end, + answer2_start=candidate_answer.document_offset.start, + answer2_end=candidate_answer.document_offset.end, + ) + + # If overlap is 0 then keep + if overlap_len == 0: + continue + + overlap_frac_answer1 = overlap_len / (ans.document_offset.end - ans.document_offset.start) + overlap_frac_answer2 = overlap_len / ( + candidate_answer.document_offset.end - candidate_answer.document_offset.start + ) + + if overlap_frac_answer1 > overlap_threshold or overlap_frac_answer2 > overlap_threshold: + keep = False + break + + return keep + + def deduplicate_by_overlap( + self, answers: List[ExtractedAnswer], overlap_threshold: Optional[float] + ) -> List[ExtractedAnswer]: + """ + This de-duplicates overlapping Extractive Answers from the same document based on how much the spans of the + answers overlap. + + :param answers: List of answers to be deduplicated. + :param overlap_threshold: If set this will remove duplicate answers if they have an overlap larger than the + supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove + one of these answers since the second answer has a 100% (1.0) overlap with the first answer. + However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so + both of these answers could be kept if this variable is set to 0.24 or lower. + If None is provided then all answers are kept. + """ + if overlap_threshold is None: + return answers + + # Initialize with the first answer and its offsets_in_document + deduplicated_answers = [answers[0]] + + # Loop over remaining answers to check for overlaps + for ans in answers[1:]: + keep = self._should_keep( + candidate_answer=ans, current_answers=deduplicated_answers, overlap_threshold=overlap_threshold + ) + if keep: + deduplicated_answers.append(ans) + + return deduplicated_answers + @component.output_types(answers=List[ExtractedAnswer]) def run( self, @@ -317,6 +434,7 @@ class ExtractiveReader: max_batch_size: Optional[int] = None, answers_per_seq: Optional[int] = None, no_answer: Optional[bool] = None, + overlap_threshold: Optional[float] = None, ): """ Locates and extracts answers from the given Documents using the given query. @@ -325,8 +443,6 @@ class ExtractiveReader: :param documents: List of Documents in which you want to search for an answer to the query. :param top_k: The maximum number of answers to return. An additional answer is returned if no_answer is set to True (default). - :param score_threshold: - :return: List of ExtractedAnswers sorted by (desc.) answer score. :param score_threshold: Returns only answers with the score above this threshold. :param max_seq_length: Maximum number of tokens. If a sequence exceeds it, the sequence is split. 
@@ -336,7 +452,17 @@ class ExtractiveReader:
         :param max_batch_size: Maximum number of samples that are fed through the model at the same time.
         :param answers_per_seq: Number of answer candidates to consider per sequence.
             This is relevant when a Document was split into multiple sequences because of max_seq_length.
+            Default: 20
         :param no_answer: Whether to return no answer scores.
+            Default: True
+        :param overlap_threshold: If set, removes duplicate answers whose overlap with another answer exceeds
+            the supplied threshold. For example, for the answers "in the river in Maine" and "the river", one
+            of the two is removed, since the second answer has a 100% (1.0) overlap with the first. However,
+            for the answers "the river in" and "in Maine" the maximum overlap fraction is only 25%, so both
+            answers are kept as long as this parameter is set to 0.25 or higher.
+            If None is provided, all answers are kept.
+            Default: 0.01
+        :return: List of ExtractedAnswers sorted by (desc.) answer score.
         """
         queries = [query]  # Temporary solution until we have decided what batching should look like in v2
         nested_documents = [documents]
@@ -348,8 +474,9 @@ class ExtractiveReader:
         max_seq_length = max_seq_length or self.max_seq_length
         stride = stride or self.stride
         max_batch_size = max_batch_size or self.max_batch_size
-        answers_per_seq = answers_per_seq or self.answers_per_seq or top_k or 20
+        answers_per_seq = answers_per_seq or self.answers_per_seq or 20
         no_answer = no_answer if no_answer is not None else self.no_answer
+        overlap_threshold = overlap_threshold or self.overlap_threshold

         flattened_queries, flattened_documents, query_ids = self._flatten_documents(queries, nested_documents)
         input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess(
@@ -385,17 +512,18 @@
         )

         answers = self._nest_answers(
-            start,
-            end,
-            probabilities,
-            flattened_documents,
-            queries,
-            answers_per_seq,
-            top_k,
-            score_threshold,
-            query_ids,
-            document_ids,
-            no_answer,
+            start=start,
+            end=end,
+            probabilities=probabilities,
+            flattened_documents=flattened_documents,
+            queries=queries,
+            answers_per_seq=answers_per_seq,
+            top_k=top_k,
+            score_threshold=score_threshold,
+            query_ids=query_ids,
+            document_ids=document_ids,
+            no_answer=no_answer,
+            overlap_threshold=overlap_threshold,
         )

         return {"answers": answers[0]}  # same temporary batching fix as above

diff --git a/releasenotes/notes/extractive-qa-answer-dedup-7ca3b94b79b38854.yaml b/releasenotes/notes/extractive-qa-answer-dedup-7ca3b94b79b38854.yaml
new file mode 100644
index 000000000..12b79fa2a
--- /dev/null
+++ b/releasenotes/notes/extractive-qa-answer-dedup-7ca3b94b79b38854.yaml
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Introduces answer deduplication on the Document level based on an overlap threshold.
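Note on the threshold math (an editorial sketch, not part of the patch): for spans that do overlap, the min() of the four end-minus-start differences in _calculate_overlap reduces to min(end1, end2) - max(start1, start2). The following standalone Python sketch, with illustrative helper names (char_overlap, is_duplicate, span) that do not exist in Haystack, reproduces that arithmetic and the _should_keep threshold test on the docstring's own example:

from typing import Tuple

Span = Tuple[int, int]  # [start, end) character offsets within one document


def char_overlap(start1: int, end1: int, start2: int, end2: int) -> int:
    """Number of characters shared by two spans; 0 if they are disjoint."""
    # Spans overlap iff (StartA <= EndB) and (StartB <= EndA), as in _calculate_overlap.
    if start1 <= end2 and start2 <= end1:
        return max(0, min(end1, end2) - max(start1, start2))
    return 0


def is_duplicate(span1: Span, span2: Span, overlap_threshold: float) -> bool:
    """True if the overlap fraction of either span exceeds the threshold (mirrors _should_keep)."""
    overlap = char_overlap(*span1, *span2)
    if overlap == 0:
        return False
    # Each fraction is measured against that span's own length.
    frac1 = overlap / (span1[1] - span1[0])
    frac2 = overlap / (span2[1] - span2[0])
    return frac1 > overlap_threshold or frac2 > overlap_threshold


context = "I want to go to the river in Maine."


def span(answer: str) -> Span:
    start = context.find(answer)
    return (start, start + len(answer))


print(char_overlap(*span("the river in"), *span("in Maine")))      # 2 -- the shared "in"
print(is_duplicate(span("the river in"), span("in Maine"), 0.01))  # True: 2/8 = 25% > 1%
print(is_duplicate(span("the river in"), span("in Maine"), 0.25))  # False: neither 25% nor ~16.7% exceeds 25%

Because the fractions are relative to each span's own length, a short answer fully contained in a longer one always scores a 100% overlap and is deduplicated at any threshold below 1.0.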
diff --git a/test/components/readers/test_extractive.py b/test/components/readers/test_extractive.py
index 69b51f5f9..3a62c33ad 100644
--- a/test/components/readers/test_extractive.py
+++ b/test/components/readers/test_extractive.py
@@ -7,7 +7,7 @@ import torch
 from transformers import pipeline

 from haystack.components.readers import ExtractiveReader
-from haystack import Document
+from haystack import Document, ExtractedAnswer


 @pytest.fixture
@@ -233,8 +233,19 @@ def test_nest_answers(mock_reader: ExtractiveReader):
     probabilities = torch.arange(5).unsqueeze(0) / 5 + torch.arange(6).unsqueeze(-1) / 25
     query_ids = [0] * 3 + [1] * 3
     document_ids = list(range(3)) * 2
-    nested_answers = mock_reader._nest_answers(
-        start, end, probabilities, example_documents[0], example_queries, 5, 3, None, query_ids, document_ids, True  # type: ignore
+    nested_answers = mock_reader._nest_answers(  # type: ignore
+        start=start,
+        end=end,
+        probabilities=probabilities,
+        flattened_documents=example_documents[0],
+        queries=example_queries,
+        answers_per_seq=5,
+        top_k=3,
+        score_threshold=None,
+        query_ids=query_ids,
+        document_ids=document_ids,
+        no_answer=True,
+        overlap_threshold=None,
     )
     expected_no_answers = [0.2 * 0.16 * 0.12, 0]
     for query, answers, expected_no_answer, probabilities in zip(
@@ -261,6 +272,196 @@ def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
     mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")


+class TestDeduplication:
+    @pytest.fixture
+    def doc1(self):
+        return Document(content="I want to go to the river in Maine.")
+
+    @pytest.fixture
+    def doc2(self):
+        return Document(content="I want to go skiing in Colorado.")
+
+    @pytest.fixture
+    def candidate_answer(self, doc1):
+        answer1 = "the river"
+        return ExtractedAnswer(
+            query="test",
+            data=answer1,
+            document=doc1,
+            document_offset=ExtractedAnswer.Span(doc1.content.find(answer1), doc1.content.find(answer1) + len(answer1)),
+            score=0.1,
+            meta={},
+        )
+
+    def test_calculate_overlap(self, mock_reader: ExtractiveReader, doc1: Document):
+        answer1 = "the river"
+        answer2 = "river in Maine"
+        overlap_in_characters = mock_reader._calculate_overlap(
+            answer1_start=doc1.content.find(answer1),
+            answer1_end=doc1.content.find(answer1) + len(answer1),
+            answer2_start=doc1.content.find(answer2),
+            answer2_end=doc1.content.find(answer2) + len(answer2),
+        )
+        assert overlap_in_characters == 5
+
+    def test_should_keep_false(
+        self, mock_reader: ExtractiveReader, doc1: Document, doc2: Document, candidate_answer: ExtractedAnswer
+    ):
+        answer2 = "river in Maine"
+        answer3 = "skiing in Colorado"
+        keep = mock_reader._should_keep(
+            candidate_answer=candidate_answer,
+            current_answers=[
+                ExtractedAnswer(
+                    query="test",
+                    data=answer2,
+                    document=doc1,
+                    document_offset=ExtractedAnswer.Span(
+                        doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)
+                    ),
+                    score=0.1,
+                    meta={},
+                ),
+                ExtractedAnswer(
+                    query="test",
+                    data=answer3,
+                    document=doc2,
+                    document_offset=ExtractedAnswer.Span(
+                        doc2.content.find(answer3), doc2.content.find(answer3) + len(answer3)
+                    ),
+                    score=0.1,
+                    meta={},
+                ),
+            ],
+            overlap_threshold=0.01,
+        )
+        assert keep is False
+
+    def test_should_keep_true(
+        self, mock_reader: ExtractiveReader, doc1: Document, doc2: Document, candidate_answer: ExtractedAnswer
+    ):
+        answer2 = "Maine"
+        answer3 = "skiing in Colorado"
+        keep = mock_reader._should_keep(
+            candidate_answer=candidate_answer,
+            current_answers=[
+                ExtractedAnswer(
+                    query="test",
+                    data=answer2,
+                    document=doc1,
+                    document_offset=ExtractedAnswer.Span(
+                        doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)
+                    ),
+                    score=0.1,
+                    meta={},
+                ),
+                ExtractedAnswer(
+                    query="test",
+                    data=answer3,
+                    document=doc2,
+                    document_offset=ExtractedAnswer.Span(
+                        doc2.content.find(answer3), doc2.content.find(answer3) + len(answer3)
+                    ),
+                    score=0.1,
+                    meta={},
+                ),
+            ],
+            overlap_threshold=0.01,
+        )
+        assert keep is True
+
+    def test_should_keep_missing_document_current_answer(
+        self, mock_reader: ExtractiveReader, doc1: Document, candidate_answer: ExtractedAnswer
+    ):
+        answer2 = "river in Maine"
+        keep = mock_reader._should_keep(
+            candidate_answer=candidate_answer,
+            current_answers=[
+                ExtractedAnswer(
+                    query="test",
+                    data=answer2,
+                    document=None,
+                    document_offset=ExtractedAnswer.Span(
+                        doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)
+                    ),
+                    score=0.1,
+                    meta={},
+                )
+            ],
+            overlap_threshold=0.01,
+        )
+        assert keep is True
+
+    def test_should_keep_missing_document_candidate_answer(
+        self, mock_reader: ExtractiveReader, doc1: Document, candidate_answer: ExtractedAnswer
+    ):
+        answer2 = "river in Maine"
+        keep = mock_reader._should_keep(
+            candidate_answer=ExtractedAnswer(
+                query="test",
+                data=answer2,
+                document=None,
+                document_offset=ExtractedAnswer.Span(
+                    doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)
+                ),
+                score=0.1,
+                meta={},
+            ),
+            current_answers=[
+                ExtractedAnswer(
+                    query="test",
+                    data=answer2,
+                    document=doc1,
+                    document_offset=ExtractedAnswer.Span(
+                        doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)
+                    ),
+                    score=0.1,
+                    meta={},
+                )
+            ],
+            overlap_threshold=0.01,
+        )
+        assert keep is True
+
+    def test_should_keep_missing_span(
+        self, mock_reader: ExtractiveReader, doc1: Document, candidate_answer: ExtractedAnswer
+    ):
+        answer2 = "river in Maine"
+        keep = mock_reader._should_keep(
+            candidate_answer=candidate_answer,
+            current_answers=[
+                ExtractedAnswer(query="test", data=answer2, document=doc1, document_offset=None, score=0.1, meta={})
+            ],
+            overlap_threshold=0.01,
+        )
+        assert keep is True
+
+    def test_deduplicate_by_overlap_none_overlap(
+        self, mock_reader: ExtractiveReader, candidate_answer: ExtractedAnswer
+    ):
+        result = mock_reader.deduplicate_by_overlap(
+            answers=[candidate_answer, candidate_answer], overlap_threshold=None
+        )
+        assert len(result) == 2
+
+    def test_deduplicate_by_overlap(
+        self, mock_reader: ExtractiveReader, candidate_answer: ExtractedAnswer, doc1: Document
+    ):
+        answer2 = "Maine"
+        extracted_answer2 = ExtractedAnswer(
+            query="test",
+            data=answer2,
+            document=doc1,
+            document_offset=ExtractedAnswer.Span(doc1.content.find(answer2), doc1.content.find(answer2) + len(answer2)),
+            score=0.1,
+            meta={},
+        )
+        result = mock_reader.deduplicate_by_overlap(
+            answers=[candidate_answer, candidate_answer, extracted_answer2], overlap_threshold=0.01
+        )
+        assert len(result) == 2
+
+
 @pytest.mark.integration
 def test_t5():
     reader = ExtractiveReader("TARUNBHATT/flan-t5-small-finetuned-squad")
@@ -274,6 +475,7 @@ def test_t5():
     assert answers[1].score == pytest.approx(0.7703777551651001)
     assert answers[2].data is None
     assert answers[2].score == pytest.approx(0.051331606147570596)
+    assert len(answers) == 3
     # Uncomment assertions below when batching is reintroduced
     # assert answers[0][2].score == pytest.approx(0.051331606147570596)
     # assert answers[1][0].data == "Jerry"
@@ -297,6 +499,7 @@ def test_roberta():
     assert answers[1].score == pytest.approx(0.857952892780304)
     assert answers[2].data is None
     assert answers[2].score == pytest.approx(0.019673851661650588)
+    assert len(answers) == 3
     # uncomment assertions below when there is batching in v2
     # assert answers[0][0].data == "Olaf Scholz"
     # assert answers[0][0].score == pytest.approx(0.8614975214004517)
@@ -314,7 +517,7 @@

 @pytest.mark.integration
 def test_matches_hf_pipeline():
-    reader = ExtractiveReader("deepset/tinyroberta-squad2", device="cpu")
+    reader = ExtractiveReader("deepset/tinyroberta-squad2", device="cpu", overlap_threshold=None)
     reader.warm_up()
     answers = reader.run(example_queries[0], [[example_documents[0][0]]][0], top_k=20, no_answer=False)[
         "answers"
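For completeness, a hedged usage sketch of the new parameter at the component level. It assumes the deepset/roberta-base-squad2 checkpoint referenced by the mocked tests above; the query and document are illustrative, and running it downloads the model:

from haystack import Document
from haystack.components.readers import ExtractiveReader

docs = [Document(content="I want to go to the river in Maine.")]

# overlap_threshold=0.01 is already the default; spelled out here for clarity.
reader = ExtractiveReader("deepset/roberta-base-squad2", overlap_threshold=0.01)
reader.warm_up()

# Heavily overlapping spans such as "the river in Maine" and "the river"
# collapse into the highest-scoring candidate; pass overlap_threshold=None
# (as test_matches_hf_pipeline does) to keep every candidate span.
result = reader.run(query="Where do I want to go?", documents=docs, top_k=3)
print([answer.data for answer in result["answers"]])  # with no_answer=True, also contains a None entry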