feat: Deduplicate duplicate Answers resulting from overlapping Documents in FARMReader (#4470)

* Deduplicate answers resulting from document split overlap

* Add tests

* Fix Pylint

* Adapt existing test

* Incorporate PR feedback
This commit is contained in:
bogdankostic 2023-03-27 20:04:59 +02:00 committed by GitHub
parent de825ded1c
commit ed1837c0c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 222 additions and 19 deletions

View File

@ -207,6 +207,7 @@ class QACandidate:
# if the original text contained multiple consecutive whitespaces
cleaned_final_text = final_text.strip()
if not cleaned_final_text:
self.answer_type = "no_answer"
return "", 0, 0
# Adjust the offsets in case of whitespace at the beginning of the answer

View File

@ -95,7 +95,7 @@ class PreProcessor(BasePreProcessor):
In this case the id will be generated by using the content and the defined metadata.
:param progress_bar: Whether to show a progress bar.
:param add_page_number: Add the number of the page a paragraph occurs in to the Document's meta
field `"page"`. Page boundaries are determined by `"\f"' character which is added
field `"page"`. Page boundaries are determined by `"\f"` character which is added
in between pages by `PDFToTextConverter`, `TikaConverter`, `ParsrConverter` and
`AzureConverter`.
:param max_chars_check: the maximum length a document is expected to have. Each document that is longer than max_chars_check in characters after pre-processing will raise a warning.
@ -371,6 +371,7 @@ class PreProcessor(BasePreProcessor):
splits_start_idxs=splits_start_idxs,
headlines=headlines,
meta=document.meta or {},
split_overlap=split_overlap,
id_hash_keys=id_hash_keys,
)
@ -482,21 +483,9 @@ class PreProcessor(BasePreProcessor):
splits_start_idxs.append(cur_start_idx)
if split_overlap:
overlap = []
processed_sents = []
word_count_overlap = 0
current_slice_copy = deepcopy(current_slice)
for idx, s in reversed(list(enumerate(current_slice))):
sen_len = len(s.split())
if word_count_overlap < split_overlap:
overlap.append(s)
word_count_overlap += sen_len
current_slice_copy.pop(idx)
else:
processed_sents = current_slice_copy
break
current_slice = list(reversed(overlap))
word_count_slice = word_count_overlap
processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice(
current_slice, split_length, split_overlap
)
else:
processed_sents = current_slice
current_slice = []
@ -530,6 +519,35 @@ class PreProcessor(BasePreProcessor):
return text_splits, splits_pages, splits_start_idxs
@staticmethod
def _get_overlap_from_slice(
current_slice: List[str], split_length: int, split_overlap: int
) -> Tuple[List[str], List[str], int]:
"""
Returns a tuple with the following elements:
- processed_sents: List of sentences that are not overlapping the with next slice (= completely processed sentences)
- next_slice: List of sentences that are overlapping with the next slice
- word_count_slice: Number of words in the next slice
"""
overlap = []
word_count_overlap = 0
current_slice_copy = deepcopy(current_slice)
# Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
for idx, s in reversed(list(enumerate(current_slice))[1:]):
sen_len = len(s.split())
if word_count_overlap < split_overlap and sen_len < split_length:
overlap.append(s)
word_count_overlap += sen_len
current_slice_copy.pop(idx)
else:
break
processed_sents = current_slice_copy
next_slice = list(reversed(overlap))
word_count_slice = word_count_overlap
return processed_sents, next_slice, word_count_slice
def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
if split_by == "passage":
elements = text.split("\n\n")
@ -580,12 +598,13 @@ class PreProcessor(BasePreProcessor):
splits_start_idxs: List[int],
headlines: List[Dict],
meta: Dict,
split_overlap: int,
id_hash_keys=Optional[List[str]],
) -> List[Document]:
"""
Creates Document objects from text splits enriching them with page number and headline information if given.
"""
documents = []
documents: List[Document] = []
earliest_rel_hl = 0
for i, txt in enumerate(text_splits):
@ -600,11 +619,35 @@ class PreProcessor(BasePreProcessor):
headlines=headlines, split_txt=txt, split_start_idx=split_start_idx, earliest_rel_hl=earliest_rel_hl
)
doc.meta["headlines"] = relevant_headlines
if split_overlap > 0:
doc.meta["_split_overlap"] = []
if i != 0:
doc_start_idx = splits_start_idxs[i]
previous_doc = documents[i - 1]
previous_doc_start_idx = splits_start_idxs[i - 1]
self._add_split_overlap_information(doc, doc_start_idx, previous_doc, previous_doc_start_idx)
documents.append(doc)
return documents
@staticmethod
def _add_split_overlap_information(
current_doc: Document, current_doc_start_idx: int, previous_doc: Document, previos_doc_start_idx: int
):
"""
Adds split overlap information to the current and previous Document's meta.
"""
overlapping_range = (current_doc_start_idx - previos_doc_start_idx, len(previous_doc.content) - 1)
if overlapping_range[0] < overlapping_range[1]:
overlapping_str = previous_doc.content[overlapping_range[0] : overlapping_range[1]]
if current_doc.content.startswith(overlapping_str):
# Add split overlap information to previous Document regarding this Document
previous_doc.meta["_split_overlap"].append({"doc_id": current_doc.id, "range": overlapping_range})
# Add split overlap information to this Document regarding the previous Document
overlapping_range = (0, overlapping_range[1] - overlapping_range[0])
current_doc.meta["_split_overlap"].append({"doc_id": previous_doc.id, "range": overlapping_range})
@staticmethod
def _extract_relevant_headlines_for_split(
headlines: List[Dict], split_txt: str, split_start_idx: int, earliest_rel_hl: int
@ -752,7 +795,8 @@ class PreProcessor(BasePreProcessor):
period_context_fmt
% {
"NonWord": sentence_tokenizer._lang_vars._re_non_word_chars,
"SentEndChars": sentence_tokenizer._lang_vars._re_sent_end_chars,
# SentEndChars might be followed by closing brackets, so we match them here.
"SentEndChars": sentence_tokenizer._lang_vars._re_sent_end_chars + r"[\)\]}]*",
},
re.UNICODE | re.VERBOSE,
)

View File

@ -913,6 +913,8 @@ class FARMReader(BaseReader):
predictions = self.inferencer.inference_from_objects(
objects=inputs, return_json=False, multiprocessing_chunksize=1
)
# Deduplicate same answers resulting from Document split overlap
predictions = self._deduplicate_predictions(predictions, documents)
# assemble answers from all the different documents & format them.
answers, max_no_ans_gap = self._extract_answers_of_predictions(predictions, top_k)
# TODO: potentially simplify return here to List[Answer] and handle no_ans_gap differently
@ -1262,6 +1264,82 @@ class FARMReader(BaseReader):
return inputs, number_of_docs, single_doc_list
def _deduplicate_predictions(self, predictions: List[QAPred], documents: List[Document]) -> List[QAPred]:
overlapping_docs = self._identify_overlapping_docs(documents)
if not overlapping_docs:
return predictions
preds_per_doc = {pred.id: pred for pred in predictions}
for pred in predictions:
# Check if current Document overlaps with Documents of other preds and continue if not
if pred.id not in overlapping_docs:
continue
relevant_overlaps = overlapping_docs[pred.id]
for ans_idx in reversed(range(len(pred.prediction))):
ans = pred.prediction[ans_idx]
if ans.answer_type != "span":
continue
for overlap in relevant_overlaps:
# Check if answer offsets are within the overlap
if not self._qa_cand_in_overlap(ans, overlap):
continue
# Check if predictions from overlapping Document are within the overlap
overlapping_doc_pred = preds_per_doc[overlap["doc_id"]]
cur_doc_overlap = [ol for ol in overlapping_docs[overlap["doc_id"]] if ol["doc_id"] == pred.id][0]
for pot_dupl_ans_idx in reversed(range(len(overlapping_doc_pred.prediction))):
pot_dupl_ans = overlapping_doc_pred.prediction[pot_dupl_ans_idx]
if pot_dupl_ans.answer_type != "span":
continue
if not self._qa_cand_in_overlap(pot_dupl_ans, cur_doc_overlap):
continue
# Check if ans and pot_dupl_ans are duplicates
if self._is_duplicate_answer(ans, overlap, pot_dupl_ans, cur_doc_overlap):
# Discard the duplicate with lower score
if ans.confidence < pot_dupl_ans.confidence:
pred.prediction.pop(ans_idx)
else:
overlapping_doc_pred.prediction.pop(pot_dupl_ans_idx)
return predictions
@staticmethod
def _is_duplicate_answer(
ans: QACandidate, ans_overlap: Dict, pot_dupl_ans: QACandidate, pot_dupl_ans_overlap: Dict
) -> bool:
answer_start_in_overlap = ans.offset_answer_start - ans_overlap["range"][0]
answer_end_in_overlap = ans.offset_answer_end - ans_overlap["range"][0]
pot_dupl_ans_start_in_overlap = pot_dupl_ans.offset_answer_start - pot_dupl_ans_overlap["range"][0]
pot_dupl_ans_end_in_overlap = pot_dupl_ans.offset_answer_end - pot_dupl_ans_overlap["range"][0]
return (
answer_start_in_overlap == pot_dupl_ans_start_in_overlap
and answer_end_in_overlap == pot_dupl_ans_end_in_overlap
)
@staticmethod
def _qa_cand_in_overlap(cand: QACandidate, overlap: Dict) -> bool:
if cand.offset_answer_start < overlap["range"][0] or cand.offset_answer_end > overlap["range"][1]:
return False
return True
@staticmethod
def _identify_overlapping_docs(documents: List[Document]) -> Dict[str, List]:
docs_by_ids = {doc.id: doc for doc in documents}
overlapping_docs = {}
for doc in documents:
if "_split_overlap" not in doc.meta:
continue
current_overlaps = [overlap for overlap in doc.meta["_split_overlap"] if overlap["doc_id"] in docs_by_ids]
if current_overlaps:
overlapping_docs[doc.id] = current_overlaps
return overlapping_docs
def calibrate_confidence_scores(
self,
document_store: BaseDocumentStore,

View File

@ -251,7 +251,7 @@ def test_id_hash_keys_from_pipeline_params():
# test_input is a tuple consisting of the parameters for split_length, split_overlap and split_respect_sentence_boundary
# and the expected index in the output list of Documents where the page number changes from 1 to 2
@pytest.mark.unit
@pytest.mark.parametrize("test_input", [(10, 0, True, 5), (10, 0, False, 4), (10, 5, True, 6), (10, 5, False, 7)])
@pytest.mark.parametrize("test_input", [(10, 0, True, 5), (10, 0, False, 4), (10, 5, True, 5), (10, 5, False, 7)])
def test_page_number_extraction(test_input):
split_length, overlap, resp_sent_boundary, exp_doc_index = test_input
preprocessor = PreProcessor(
@ -540,3 +540,62 @@ def test_preprocessor_very_long_document(caplog):
assert results == documents
for i in range(5):
assert f"is 6{i} characters long after preprocessing, where the maximum length should be 10." in caplog.text
@pytest.mark.unit
def test_split_respect_sentence_boundary_exceeding_split_len_not_repeated():
preproc = PreProcessor(split_length=13, split_overlap=3, split_by="word", split_respect_sentence_boundary=True)
document = Document(
content=(
"This is a test sentence with many many words that exceeds the split length and should not be repeated. "
"This is another test sentence. (This is a third test sentence.) "
"This is the last test sentence."
)
)
documents = preproc.process(document)
assert len(documents) == 3
assert (
documents[0].content
== "This is a test sentence with many many words that exceeds the split length and should not be repeated. "
)
assert "This is a test sentence with many many words" not in documents[1].content
assert "This is a test sentence with many many words" not in documents[2].content
@pytest.mark.unit
def test_split_overlap_information():
preproc = PreProcessor(split_length=13, split_overlap=3, split_by="word", split_respect_sentence_boundary=True)
document = Document(
content=(
"This is a test sentence with many many words that exceeds the split length and should not be repeated. "
"This is another test sentence. (This is a third test sentence.) This is the fourth sentence. "
"This is the last test sentence."
)
)
documents = preproc.process(document)
assert len(documents) == 4
# The first Document should not overlap with any other Document as it exceeds the split length, the other Documents
# should overlap with the previous Document (if applicable) and the next Document (if applicable)
assert len(documents[0].meta["_split_overlap"]) == 0
assert len(documents[1].meta["_split_overlap"]) == 1
assert len(documents[2].meta["_split_overlap"]) == 2
assert len(documents[3].meta["_split_overlap"]) == 1
assert documents[1].meta["_split_overlap"][0]["doc_id"] == documents[2].id
assert documents[2].meta["_split_overlap"][0]["doc_id"] == documents[1].id
assert documents[2].meta["_split_overlap"][1]["doc_id"] == documents[3].id
assert documents[3].meta["_split_overlap"][0]["doc_id"] == documents[2].id
doc1_overlap_doc2 = documents[1].meta["_split_overlap"][0]["range"]
doc2_overlap_doc1 = documents[2].meta["_split_overlap"][0]["range"]
assert (
documents[1].content[doc1_overlap_doc2[0] : doc1_overlap_doc2[1]]
== documents[2].content[doc2_overlap_doc1[0] : doc2_overlap_doc1[1]]
)
doc2_overlap_doc3 = documents[2].meta["_split_overlap"][1]["range"]
doc3_overlap_doc2 = documents[3].meta["_split_overlap"][0]["range"]
assert (
documents[2].content[doc2_overlap_doc3[0] : doc2_overlap_doc3[1]]
== documents[3].content[doc3_overlap_doc2[0] : doc3_overlap_doc2[1]]
)

View File

@ -144,6 +144,27 @@ def test_no_answer_output(no_answer_reader, docs):
assert len(no_answer_prediction["answers"]) == 5
@pytest.mark.integration
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_deduplication_for_overlapping_documents(reader):
docs = [
Document(
content="My name is Carla. I live in Berlin.",
id="doc1",
meta={"_split_id": 0, "_split_overlap": [{"doc_id": "doc2", "range": (18, 35)}]},
),
Document(
content="I live in Berlin. My friends call me Carla.",
id="doc2",
meta={"_split_id": 1, "_split_overlap": [{"doc_id": "doc1", "range": (0, 17)}]},
),
]
prediction = reader.predict(query="Where does Carla live?", documents=docs, top_k=5)
# Check that there are no duplicate answers
assert len(set(ans.answer for ans in prediction["answers"])) == len(prediction["answers"])
@pytest.mark.integration
def test_model_download_options():
# download disabled and model is not cached locally