feat: Deduplicate duplicate Answers resulting from overlapping Documents in FARMReader (#4470)

* Deduplicate answers resulting from document split overlap

* Add tests

* Fix Pylint

* Adapt existing test

* Incorporate PR feedback
Authored by bogdankostic on 2023-03-27 20:04:59 +02:00, committed by GitHub
Parent: de825ded1c
Commit: ed1837c0c9
5 changed files with 222 additions and 19 deletions
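
For context, a minimal end-to-end sketch of the behaviour this PR targets (a hedged illustration: the model name and texts are placeholders, and only the public PreProcessor/FARMReader calls that appear in the diffs below are assumed):

    from haystack import Document
    from haystack.nodes import FARMReader, PreProcessor

    # Splitting with overlap lets the same passage appear at the end of one
    # Document and at the start of the next; before this PR, the reader could
    # return the identical answer span once per copy.
    preproc = PreProcessor(
        split_by="word", split_length=13, split_overlap=3, split_respect_sentence_boundary=True
    )
    docs = preproc.process(Document(content="My name is Carla. I live in Berlin. ..."))

    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")
    prediction = reader.predict(query="Where does Carla live?", documents=docs, top_k=5)

    # After this change, answers found inside an overlapping region are
    # deduplicated, keeping the candidate with the higher confidence.
    answers = [ans.answer for ans in prediction["answers"]]
    assert len(set(answers)) == len(answers)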


@@ -207,6 +207,7 @@ class QACandidate:
         # if the original text contained multiple consecutive whitespaces
         cleaned_final_text = final_text.strip()
         if not cleaned_final_text:
+            self.answer_type = "no_answer"
             return "", 0, 0
         # Adjust the offsets in case of whitespace at the beginning of the answer
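
In isolation, the added line changes how empty-after-cleaning answers are treated (a minimal sketch of the surrounding logic, not the full method):

    # Sketch: an answer that is pure whitespace after cleaning is now
    # reclassified as "no_answer" instead of remaining a "span" with empty text.
    final_text = "   "
    cleaned_final_text = final_text.strip()
    if not cleaned_final_text:
        answer_type = "no_answer"  # new behaviour introduced by this hunk
        text, start, end = "", 0, 0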


@@ -95,7 +95,7 @@ class PreProcessor(BasePreProcessor):
                                 In this case the id will be generated by using the content and the defined metadata.
         :param progress_bar: Whether to show a progress bar.
         :param add_page_number: Add the number of the page a paragraph occurs in to the Document's meta
-                                field `"page"`. Page boundaries are determined by `"\f"' character which is added
+                                field `"page"`. Page boundaries are determined by `"\f"` character which is added
                                 in between pages by `PDFToTextConverter`, `TikaConverter`, `ParsrConverter` and
                                 `AzureConverter`.
         :param max_chars_check: the maximum length a document is expected to have. Each document that is longer than max_chars_check in characters after pre-processing will raise a warning.
@@ -371,6 +371,7 @@ class PreProcessor(BasePreProcessor):
                 splits_start_idxs=splits_start_idxs,
                 headlines=headlines,
                 meta=document.meta or {},
+                split_overlap=split_overlap,
                 id_hash_keys=id_hash_keys,
             )
@@ -482,21 +483,9 @@
                 splits_start_idxs.append(cur_start_idx)
                 if split_overlap:
-                    overlap = []
-                    processed_sents = []
-                    word_count_overlap = 0
-                    current_slice_copy = deepcopy(current_slice)
-                    for idx, s in reversed(list(enumerate(current_slice))):
-                        sen_len = len(s.split())
-                        if word_count_overlap < split_overlap:
-                            overlap.append(s)
-                            word_count_overlap += sen_len
-                            current_slice_copy.pop(idx)
-                        else:
-                            processed_sents = current_slice_copy
-                            break
-                    current_slice = list(reversed(overlap))
-                    word_count_slice = word_count_overlap
+                    processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice(
+                        current_slice, split_length, split_overlap
+                    )
                 else:
                     processed_sents = current_slice
                     current_slice = []
@@ -530,6 +519,35 @@ class PreProcessor(BasePreProcessor):
         return text_splits, splits_pages, splits_start_idxs

+    @staticmethod
+    def _get_overlap_from_slice(
+        current_slice: List[str], split_length: int, split_overlap: int
+    ) -> Tuple[List[str], List[str], int]:
+        """
+        Returns a tuple with the following elements:
+        - processed_sents: List of sentences that are not overlapping with the next slice (= completely processed sentences)
+        - next_slice: List of sentences that are overlapping with the next slice
+        - word_count_slice: Number of words in the next slice
+        """
+        overlap = []
+        word_count_overlap = 0
+        current_slice_copy = deepcopy(current_slice)
+        # The next overlapping Document should not start exactly like the previous one, so we skip the first sentence
+        for idx, s in reversed(list(enumerate(current_slice))[1:]):
+            sen_len = len(s.split())
+            if word_count_overlap < split_overlap and sen_len < split_length:
+                overlap.append(s)
+                word_count_overlap += sen_len
+                current_slice_copy.pop(idx)
+            else:
+                break
+        processed_sents = current_slice_copy
+        next_slice = list(reversed(overlap))
+        word_count_slice = word_count_overlap
+        return processed_sents, next_slice, word_count_slice
+
     def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
         if split_by == "passage":
             elements = text.split("\n\n")
@@ -580,12 +598,13 @@ class PreProcessor(BasePreProcessor):
         splits_start_idxs: List[int],
         headlines: List[Dict],
         meta: Dict,
+        split_overlap: int,
         id_hash_keys=Optional[List[str]],
     ) -> List[Document]:
         """
         Creates Document objects from text splits enriching them with page number and headline information if given.
         """
-        documents = []
+        documents: List[Document] = []
         earliest_rel_hl = 0
         for i, txt in enumerate(text_splits):
@@ -600,11 +619,35 @@
                 headlines=headlines, split_txt=txt, split_start_idx=split_start_idx, earliest_rel_hl=earliest_rel_hl
             )
             doc.meta["headlines"] = relevant_headlines
+            if split_overlap > 0:
+                doc.meta["_split_overlap"] = []
+                if i != 0:
+                    doc_start_idx = splits_start_idxs[i]
+                    previous_doc = documents[i - 1]
+                    previous_doc_start_idx = splits_start_idxs[i - 1]
+                    self._add_split_overlap_information(doc, doc_start_idx, previous_doc, previous_doc_start_idx)
             documents.append(doc)

         return documents

+    @staticmethod
+    def _add_split_overlap_information(
+        current_doc: Document, current_doc_start_idx: int, previous_doc: Document, previous_doc_start_idx: int
+    ):
+        """
+        Adds split overlap information to the current and previous Document's meta.
+        """
+        overlapping_range = (current_doc_start_idx - previous_doc_start_idx, len(previous_doc.content) - 1)
+        if overlapping_range[0] < overlapping_range[1]:
+            overlapping_str = previous_doc.content[overlapping_range[0] : overlapping_range[1]]
+            if current_doc.content.startswith(overlapping_str):
+                # Add split overlap information to the previous Document regarding this Document
+                previous_doc.meta["_split_overlap"].append({"doc_id": current_doc.id, "range": overlapping_range})
+                # Add split overlap information to this Document regarding the previous Document
+                overlapping_range = (0, overlapping_range[1] - overlapping_range[0])
+                current_doc.meta["_split_overlap"].append({"doc_id": previous_doc.id, "range": overlapping_range})
+
     @staticmethod
     def _extract_relevant_headlines_for_split(
         headlines: List[Dict], split_txt: str, split_start_idx: int, earliest_rel_hl: int
@@ -752,7 +795,8 @@ class PreProcessor(BasePreProcessor):
             period_context_fmt
             % {
                 "NonWord": sentence_tokenizer._lang_vars._re_non_word_chars,
-                "SentEndChars": sentence_tokenizer._lang_vars._re_sent_end_chars,
+                # SentEndChars might be followed by closing brackets, so we match them here.
+                "SentEndChars": sentence_tokenizer._lang_vars._re_sent_end_chars + r"[\)\]}]*",
             },
             re.UNICODE | re.VERBOSE,
         )
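
To make the new overlap computation easier to follow, here is a self-contained re-implementation of _get_overlap_from_slice with the same logic (illustrative only; function and variable names are mine, the behaviour mirrors the diff above):

    from typing import List, Tuple

    def get_overlap_from_slice(
        current_slice: List[str], split_length: int, split_overlap: int
    ) -> Tuple[List[str], List[str], int]:
        # Walk the slice backwards, collecting sentences until the overlap budget
        # (counted in words) is reached. The first sentence is always excluded so
        # the next Document never starts exactly like the previous one, and a
        # sentence that alone reaches split_length is never carried over.
        overlap: List[str] = []
        word_count = 0
        remaining = list(current_slice)
        for idx in range(len(current_slice) - 1, 0, -1):
            sen_len = len(current_slice[idx].split())
            if word_count < split_overlap and sen_len < split_length:
                overlap.append(current_slice[idx])
                word_count += sen_len
                remaining.pop(idx)
            else:
                break
        return remaining, list(reversed(overlap)), word_count

    # With a 3-word overlap budget, only the last sentence is carried over:
    done, nxt, n = get_overlap_from_slice(
        ["This is one sentence.", "Here is another.", "And a third."], 13, 3
    )
    assert done == ["This is one sentence.", "Here is another."]
    assert nxt == ["And a third."] and n == 3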


@@ -913,6 +913,8 @@ class FARMReader(BaseReader):
         predictions = self.inferencer.inference_from_objects(
             objects=inputs, return_json=False, multiprocessing_chunksize=1
         )
+        # Deduplicate identical answers resulting from Document split overlap
+        predictions = self._deduplicate_predictions(predictions, documents)
         # assemble answers from all the different documents & format them.
         answers, max_no_ans_gap = self._extract_answers_of_predictions(predictions, top_k)
         # TODO: potentially simplify return here to List[Answer] and handle no_ans_gap differently
@@ -1262,6 +1264,82 @@
         return inputs, number_of_docs, single_doc_list

+    def _deduplicate_predictions(self, predictions: List[QAPred], documents: List[Document]) -> List[QAPred]:
+        overlapping_docs = self._identify_overlapping_docs(documents)
+        if not overlapping_docs:
+            return predictions
+
+        preds_per_doc = {pred.id: pred for pred in predictions}
+        for pred in predictions:
+            # Check if current Document overlaps with Documents of other preds and continue if not
+            if pred.id not in overlapping_docs:
+                continue
+
+            relevant_overlaps = overlapping_docs[pred.id]
+            for ans_idx in reversed(range(len(pred.prediction))):
+                ans = pred.prediction[ans_idx]
+                if ans.answer_type != "span":
+                    continue
+
+                for overlap in relevant_overlaps:
+                    # Check if answer offsets are within the overlap
+                    if not self._qa_cand_in_overlap(ans, overlap):
+                        continue
+
+                    # Check if predictions from overlapping Document are within the overlap
+                    overlapping_doc_pred = preds_per_doc[overlap["doc_id"]]
+                    cur_doc_overlap = [ol for ol in overlapping_docs[overlap["doc_id"]] if ol["doc_id"] == pred.id][0]
+                    for pot_dupl_ans_idx in reversed(range(len(overlapping_doc_pred.prediction))):
+                        pot_dupl_ans = overlapping_doc_pred.prediction[pot_dupl_ans_idx]
+                        if pot_dupl_ans.answer_type != "span":
+                            continue
+                        if not self._qa_cand_in_overlap(pot_dupl_ans, cur_doc_overlap):
+                            continue
+
+                        # Check if ans and pot_dupl_ans are duplicates
+                        if self._is_duplicate_answer(ans, overlap, pot_dupl_ans, cur_doc_overlap):
+                            # Discard the duplicate with lower score
+                            if ans.confidence < pot_dupl_ans.confidence:
+                                pred.prediction.pop(ans_idx)
+                            else:
+                                overlapping_doc_pred.prediction.pop(pot_dupl_ans_idx)
+
+        return predictions
+
+    @staticmethod
+    def _is_duplicate_answer(
+        ans: QACandidate, ans_overlap: Dict, pot_dupl_ans: QACandidate, pot_dupl_ans_overlap: Dict
+    ) -> bool:
+        answer_start_in_overlap = ans.offset_answer_start - ans_overlap["range"][0]
+        answer_end_in_overlap = ans.offset_answer_end - ans_overlap["range"][0]
+
+        pot_dupl_ans_start_in_overlap = pot_dupl_ans.offset_answer_start - pot_dupl_ans_overlap["range"][0]
+        pot_dupl_ans_end_in_overlap = pot_dupl_ans.offset_answer_end - pot_dupl_ans_overlap["range"][0]
+
+        return (
+            answer_start_in_overlap == pot_dupl_ans_start_in_overlap
+            and answer_end_in_overlap == pot_dupl_ans_end_in_overlap
+        )
+
+    @staticmethod
+    def _qa_cand_in_overlap(cand: QACandidate, overlap: Dict) -> bool:
+        if cand.offset_answer_start < overlap["range"][0] or cand.offset_answer_end > overlap["range"][1]:
+            return False
+        return True
+
+    @staticmethod
+    def _identify_overlapping_docs(documents: List[Document]) -> Dict[str, List]:
+        docs_by_ids = {doc.id: doc for doc in documents}
+        overlapping_docs = {}
+        for doc in documents:
+            if "_split_overlap" not in doc.meta:
+                continue
+            current_overlaps = [overlap for overlap in doc.meta["_split_overlap"] if overlap["doc_id"] in docs_by_ids]
+            if current_overlaps:
+                overlapping_docs[doc.id] = current_overlaps
+
+        return overlapping_docs
+
     def calibrate_confidence_scores(
         self,
         document_store: BaseDocumentStore,
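
The duplicate test above boils down to shifting both candidates' offsets into the shared overlap window and comparing. A worked example with hypothetical numbers (they match the "Carla" Documents used in the test further below; end offsets are treated as exclusive):

    # doc1 = "My name is Carla. I live in Berlin."
    # doc2 = "I live in Berlin. My friends call me Carla."
    doc1_overlap = {"doc_id": "doc2", "range": (18, 35)}  # tail of doc1
    doc2_overlap = {"doc_id": "doc1", "range": (0, 17)}   # head of doc2

    ans_in_doc1 = (28, 34)  # "Berlin" spans [28, 34) in doc1
    ans_in_doc2 = (10, 16)  # "Berlin" spans [10, 16) in doc2

    # Shift each span by the start of its Document's overlap range.
    shifted_1 = (ans_in_doc1[0] - doc1_overlap["range"][0], ans_in_doc1[1] - doc1_overlap["range"][0])
    shifted_2 = (ans_in_doc2[0] - doc2_overlap["range"][0], ans_in_doc2[1] - doc2_overlap["range"][0])

    # Identical positions within the overlap -> duplicates; _deduplicate_predictions
    # then drops the lower-scoring candidate.
    assert shifted_1 == shifted_2 == (10, 16)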


@@ -251,7 +251,7 @@ def test_id_hash_keys_from_pipeline_params():
 # test_input is a tuple consisting of the parameters for split_length, split_overlap and split_respect_sentence_boundary
 # and the expected index in the output list of Documents where the page number changes from 1 to 2
 @pytest.mark.unit
-@pytest.mark.parametrize("test_input", [(10, 0, True, 5), (10, 0, False, 4), (10, 5, True, 6), (10, 5, False, 7)])
+@pytest.mark.parametrize("test_input", [(10, 0, True, 5), (10, 0, False, 4), (10, 5, True, 5), (10, 5, False, 7)])
 def test_page_number_extraction(test_input):
     split_length, overlap, resp_sent_boundary, exp_doc_index = test_input
     preprocessor = PreProcessor(
@@ -540,3 +540,62 @@ def test_preprocessor_very_long_document(caplog):
     assert results == documents
     for i in range(5):
         assert f"is 6{i} characters long after preprocessing, where the maximum length should be 10." in caplog.text
+
+
+@pytest.mark.unit
+def test_split_respect_sentence_boundary_exceeding_split_len_not_repeated():
+    preproc = PreProcessor(split_length=13, split_overlap=3, split_by="word", split_respect_sentence_boundary=True)
+    document = Document(
+        content=(
+            "This is a test sentence with many many words that exceeds the split length and should not be repeated. "
+            "This is another test sentence. (This is a third test sentence.) "
+            "This is the last test sentence."
+        )
+    )
+
+    documents = preproc.process(document)
+    assert len(documents) == 3
+    assert (
+        documents[0].content
+        == "This is a test sentence with many many words that exceeds the split length and should not be repeated. "
+    )
+    assert "This is a test sentence with many many words" not in documents[1].content
+    assert "This is a test sentence with many many words" not in documents[2].content
+
+
+@pytest.mark.unit
+def test_split_overlap_information():
+    preproc = PreProcessor(split_length=13, split_overlap=3, split_by="word", split_respect_sentence_boundary=True)
+    document = Document(
+        content=(
+            "This is a test sentence with many many words that exceeds the split length and should not be repeated. "
+            "This is another test sentence. (This is a third test sentence.) This is the fourth sentence. "
+            "This is the last test sentence."
+        )
+    )
+
+    documents = preproc.process(document)
+    assert len(documents) == 4
+
+    # The first Document should not overlap with any other Document as it exceeds the split length; the other Documents
+    # should overlap with the previous Document (if applicable) and the next Document (if applicable)
+    assert len(documents[0].meta["_split_overlap"]) == 0
+    assert len(documents[1].meta["_split_overlap"]) == 1
+    assert len(documents[2].meta["_split_overlap"]) == 2
+    assert len(documents[3].meta["_split_overlap"]) == 1
+
+    assert documents[1].meta["_split_overlap"][0]["doc_id"] == documents[2].id
+    assert documents[2].meta["_split_overlap"][0]["doc_id"] == documents[1].id
+    assert documents[2].meta["_split_overlap"][1]["doc_id"] == documents[3].id
+    assert documents[3].meta["_split_overlap"][0]["doc_id"] == documents[2].id
+
+    doc1_overlap_doc2 = documents[1].meta["_split_overlap"][0]["range"]
+    doc2_overlap_doc1 = documents[2].meta["_split_overlap"][0]["range"]
+    assert (
+        documents[1].content[doc1_overlap_doc2[0] : doc1_overlap_doc2[1]]
+        == documents[2].content[doc2_overlap_doc1[0] : doc2_overlap_doc1[1]]
+    )
+
+    doc2_overlap_doc3 = documents[2].meta["_split_overlap"][1]["range"]
+    doc3_overlap_doc2 = documents[3].meta["_split_overlap"][0]["range"]
+    assert (
+        documents[2].content[doc2_overlap_doc3[0] : doc2_overlap_doc3[1]]
+        == documents[3].content[doc3_overlap_doc2[0] : doc3_overlap_doc2[1]]
+    )
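
For quick inspection outside of a test, the new metadata can be printed directly (a sketch; the content string is a placeholder):

    from haystack import Document
    from haystack.nodes import PreProcessor

    preproc = PreProcessor(split_by="word", split_length=13, split_overlap=3,
                           split_respect_sentence_boundary=True)
    docs = preproc.process(Document(content="First sentence here. Second sentence here. Third sentence here."))
    for doc in docs:
        # Each entry names a neighbouring split and the character range of this
        # Document's content that the two splits share.
        print(doc.id, doc.meta.get("_split_overlap", []))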


@@ -144,6 +144,27 @@ def test_no_answer_output(no_answer_reader, docs):
     assert len(no_answer_prediction["answers"]) == 5


+@pytest.mark.integration
+@pytest.mark.parametrize("reader", ["farm"], indirect=True)
+def test_deduplication_for_overlapping_documents(reader):
+    docs = [
+        Document(
+            content="My name is Carla. I live in Berlin.",
+            id="doc1",
+            meta={"_split_id": 0, "_split_overlap": [{"doc_id": "doc2", "range": (18, 35)}]},
+        ),
+        Document(
+            content="I live in Berlin. My friends call me Carla.",
+            id="doc2",
+            meta={"_split_id": 1, "_split_overlap": [{"doc_id": "doc1", "range": (0, 17)}]},
+        ),
+    ]
+    prediction = reader.predict(query="Where does Carla live?", documents=docs, top_k=5)
+
+    # Check that there are no duplicate answers
+    assert len(set(ans.answer for ans in prediction["answers"])) == len(prediction["answers"])
+
+
 @pytest.mark.integration
 def test_model_download_options():
     # download disabled and model is not cached locally