mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-26 08:33:51 +00:00
feat: Deduplicate duplicate Answers resulting from overlapping Documents in FARMReader
(#4470)
* Deduplicate answers resulting from document split overlap * Add tests * Fix Pylint * Adapt existing test * Incorporate PR feedback
This commit is contained in:
parent
de825ded1c
commit
ed1837c0c9
@ -207,6 +207,7 @@ class QACandidate:
|
||||
# if the original text contained multiple consecutive whitespaces
|
||||
cleaned_final_text = final_text.strip()
|
||||
if not cleaned_final_text:
|
||||
self.answer_type = "no_answer"
|
||||
return "", 0, 0
|
||||
|
||||
# Adjust the offsets in case of whitespace at the beginning of the answer
|
||||
|
@ -95,7 +95,7 @@ class PreProcessor(BasePreProcessor):
|
||||
In this case the id will be generated by using the content and the defined metadata.
|
||||
:param progress_bar: Whether to show a progress bar.
|
||||
:param add_page_number: Add the number of the page a paragraph occurs in to the Document's meta
|
||||
field `"page"`. Page boundaries are determined by `"\f"' character which is added
|
||||
field `"page"`. Page boundaries are determined by `"\f"` character which is added
|
||||
in between pages by `PDFToTextConverter`, `TikaConverter`, `ParsrConverter` and
|
||||
`AzureConverter`.
|
||||
:param max_chars_check: the maximum length a document is expected to have. Each document that is longer than max_chars_check in characters after pre-processing will raise a warning.
|
||||
@ -371,6 +371,7 @@ class PreProcessor(BasePreProcessor):
|
||||
splits_start_idxs=splits_start_idxs,
|
||||
headlines=headlines,
|
||||
meta=document.meta or {},
|
||||
split_overlap=split_overlap,
|
||||
id_hash_keys=id_hash_keys,
|
||||
)
|
||||
|
||||
@ -482,21 +483,9 @@ class PreProcessor(BasePreProcessor):
|
||||
splits_start_idxs.append(cur_start_idx)
|
||||
|
||||
if split_overlap:
|
||||
overlap = []
|
||||
processed_sents = []
|
||||
word_count_overlap = 0
|
||||
current_slice_copy = deepcopy(current_slice)
|
||||
for idx, s in reversed(list(enumerate(current_slice))):
|
||||
sen_len = len(s.split())
|
||||
if word_count_overlap < split_overlap:
|
||||
overlap.append(s)
|
||||
word_count_overlap += sen_len
|
||||
current_slice_copy.pop(idx)
|
||||
else:
|
||||
processed_sents = current_slice_copy
|
||||
break
|
||||
current_slice = list(reversed(overlap))
|
||||
word_count_slice = word_count_overlap
|
||||
processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice(
|
||||
current_slice, split_length, split_overlap
|
||||
)
|
||||
else:
|
||||
processed_sents = current_slice
|
||||
current_slice = []
|
||||
@ -530,6 +519,35 @@ class PreProcessor(BasePreProcessor):
|
||||
|
||||
return text_splits, splits_pages, splits_start_idxs
|
||||
|
||||
@staticmethod
|
||||
def _get_overlap_from_slice(
|
||||
current_slice: List[str], split_length: int, split_overlap: int
|
||||
) -> Tuple[List[str], List[str], int]:
|
||||
"""
|
||||
Returns a tuple with the following elements:
|
||||
- processed_sents: List of sentences that are not overlapping the with next slice (= completely processed sentences)
|
||||
- next_slice: List of sentences that are overlapping with the next slice
|
||||
- word_count_slice: Number of words in the next slice
|
||||
"""
|
||||
|
||||
overlap = []
|
||||
word_count_overlap = 0
|
||||
current_slice_copy = deepcopy(current_slice)
|
||||
# Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
|
||||
for idx, s in reversed(list(enumerate(current_slice))[1:]):
|
||||
sen_len = len(s.split())
|
||||
if word_count_overlap < split_overlap and sen_len < split_length:
|
||||
overlap.append(s)
|
||||
word_count_overlap += sen_len
|
||||
current_slice_copy.pop(idx)
|
||||
else:
|
||||
break
|
||||
processed_sents = current_slice_copy
|
||||
next_slice = list(reversed(overlap))
|
||||
word_count_slice = word_count_overlap
|
||||
|
||||
return processed_sents, next_slice, word_count_slice
|
||||
|
||||
def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
|
||||
if split_by == "passage":
|
||||
elements = text.split("\n\n")
|
||||
@ -580,12 +598,13 @@ class PreProcessor(BasePreProcessor):
|
||||
splits_start_idxs: List[int],
|
||||
headlines: List[Dict],
|
||||
meta: Dict,
|
||||
split_overlap: int,
|
||||
id_hash_keys=Optional[List[str]],
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Creates Document objects from text splits enriching them with page number and headline information if given.
|
||||
"""
|
||||
documents = []
|
||||
documents: List[Document] = []
|
||||
|
||||
earliest_rel_hl = 0
|
||||
for i, txt in enumerate(text_splits):
|
||||
@ -600,11 +619,35 @@ class PreProcessor(BasePreProcessor):
|
||||
headlines=headlines, split_txt=txt, split_start_idx=split_start_idx, earliest_rel_hl=earliest_rel_hl
|
||||
)
|
||||
doc.meta["headlines"] = relevant_headlines
|
||||
if split_overlap > 0:
|
||||
doc.meta["_split_overlap"] = []
|
||||
if i != 0:
|
||||
doc_start_idx = splits_start_idxs[i]
|
||||
previous_doc = documents[i - 1]
|
||||
previous_doc_start_idx = splits_start_idxs[i - 1]
|
||||
self._add_split_overlap_information(doc, doc_start_idx, previous_doc, previous_doc_start_idx)
|
||||
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def _add_split_overlap_information(
|
||||
current_doc: Document, current_doc_start_idx: int, previous_doc: Document, previos_doc_start_idx: int
|
||||
):
|
||||
"""
|
||||
Adds split overlap information to the current and previous Document's meta.
|
||||
"""
|
||||
overlapping_range = (current_doc_start_idx - previos_doc_start_idx, len(previous_doc.content) - 1)
|
||||
if overlapping_range[0] < overlapping_range[1]:
|
||||
overlapping_str = previous_doc.content[overlapping_range[0] : overlapping_range[1]]
|
||||
if current_doc.content.startswith(overlapping_str):
|
||||
# Add split overlap information to previous Document regarding this Document
|
||||
previous_doc.meta["_split_overlap"].append({"doc_id": current_doc.id, "range": overlapping_range})
|
||||
# Add split overlap information to this Document regarding the previous Document
|
||||
overlapping_range = (0, overlapping_range[1] - overlapping_range[0])
|
||||
current_doc.meta["_split_overlap"].append({"doc_id": previous_doc.id, "range": overlapping_range})
|
||||
|
||||
@staticmethod
|
||||
def _extract_relevant_headlines_for_split(
|
||||
headlines: List[Dict], split_txt: str, split_start_idx: int, earliest_rel_hl: int
|
||||
@ -752,7 +795,8 @@ class PreProcessor(BasePreProcessor):
|
||||
period_context_fmt
|
||||
% {
|
||||
"NonWord": sentence_tokenizer._lang_vars._re_non_word_chars,
|
||||
"SentEndChars": sentence_tokenizer._lang_vars._re_sent_end_chars,
|
||||
# SentEndChars might be followed by closing brackets, so we match them here.
|
||||
"SentEndChars": sentence_tokenizer._lang_vars._re_sent_end_chars + r"[\)\]}]*",
|
||||
},
|
||||
re.UNICODE | re.VERBOSE,
|
||||
)
|
||||
|
@ -913,6 +913,8 @@ class FARMReader(BaseReader):
|
||||
predictions = self.inferencer.inference_from_objects(
|
||||
objects=inputs, return_json=False, multiprocessing_chunksize=1
|
||||
)
|
||||
# Deduplicate same answers resulting from Document split overlap
|
||||
predictions = self._deduplicate_predictions(predictions, documents)
|
||||
# assemble answers from all the different documents & format them.
|
||||
answers, max_no_ans_gap = self._extract_answers_of_predictions(predictions, top_k)
|
||||
# TODO: potentially simplify return here to List[Answer] and handle no_ans_gap differently
|
||||
@ -1262,6 +1264,82 @@ class FARMReader(BaseReader):
|
||||
|
||||
return inputs, number_of_docs, single_doc_list
|
||||
|
||||
def _deduplicate_predictions(self, predictions: List[QAPred], documents: List[Document]) -> List[QAPred]:
|
||||
overlapping_docs = self._identify_overlapping_docs(documents)
|
||||
if not overlapping_docs:
|
||||
return predictions
|
||||
|
||||
preds_per_doc = {pred.id: pred for pred in predictions}
|
||||
for pred in predictions:
|
||||
# Check if current Document overlaps with Documents of other preds and continue if not
|
||||
if pred.id not in overlapping_docs:
|
||||
continue
|
||||
|
||||
relevant_overlaps = overlapping_docs[pred.id]
|
||||
for ans_idx in reversed(range(len(pred.prediction))):
|
||||
ans = pred.prediction[ans_idx]
|
||||
if ans.answer_type != "span":
|
||||
continue
|
||||
|
||||
for overlap in relevant_overlaps:
|
||||
# Check if answer offsets are within the overlap
|
||||
if not self._qa_cand_in_overlap(ans, overlap):
|
||||
continue
|
||||
|
||||
# Check if predictions from overlapping Document are within the overlap
|
||||
overlapping_doc_pred = preds_per_doc[overlap["doc_id"]]
|
||||
cur_doc_overlap = [ol for ol in overlapping_docs[overlap["doc_id"]] if ol["doc_id"] == pred.id][0]
|
||||
for pot_dupl_ans_idx in reversed(range(len(overlapping_doc_pred.prediction))):
|
||||
pot_dupl_ans = overlapping_doc_pred.prediction[pot_dupl_ans_idx]
|
||||
if pot_dupl_ans.answer_type != "span":
|
||||
continue
|
||||
if not self._qa_cand_in_overlap(pot_dupl_ans, cur_doc_overlap):
|
||||
continue
|
||||
|
||||
# Check if ans and pot_dupl_ans are duplicates
|
||||
if self._is_duplicate_answer(ans, overlap, pot_dupl_ans, cur_doc_overlap):
|
||||
# Discard the duplicate with lower score
|
||||
if ans.confidence < pot_dupl_ans.confidence:
|
||||
pred.prediction.pop(ans_idx)
|
||||
else:
|
||||
overlapping_doc_pred.prediction.pop(pot_dupl_ans_idx)
|
||||
|
||||
return predictions
|
||||
|
||||
@staticmethod
|
||||
def _is_duplicate_answer(
|
||||
ans: QACandidate, ans_overlap: Dict, pot_dupl_ans: QACandidate, pot_dupl_ans_overlap: Dict
|
||||
) -> bool:
|
||||
answer_start_in_overlap = ans.offset_answer_start - ans_overlap["range"][0]
|
||||
answer_end_in_overlap = ans.offset_answer_end - ans_overlap["range"][0]
|
||||
|
||||
pot_dupl_ans_start_in_overlap = pot_dupl_ans.offset_answer_start - pot_dupl_ans_overlap["range"][0]
|
||||
pot_dupl_ans_end_in_overlap = pot_dupl_ans.offset_answer_end - pot_dupl_ans_overlap["range"][0]
|
||||
|
||||
return (
|
||||
answer_start_in_overlap == pot_dupl_ans_start_in_overlap
|
||||
and answer_end_in_overlap == pot_dupl_ans_end_in_overlap
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _qa_cand_in_overlap(cand: QACandidate, overlap: Dict) -> bool:
|
||||
if cand.offset_answer_start < overlap["range"][0] or cand.offset_answer_end > overlap["range"][1]:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _identify_overlapping_docs(documents: List[Document]) -> Dict[str, List]:
|
||||
docs_by_ids = {doc.id: doc for doc in documents}
|
||||
overlapping_docs = {}
|
||||
for doc in documents:
|
||||
if "_split_overlap" not in doc.meta:
|
||||
continue
|
||||
current_overlaps = [overlap for overlap in doc.meta["_split_overlap"] if overlap["doc_id"] in docs_by_ids]
|
||||
if current_overlaps:
|
||||
overlapping_docs[doc.id] = current_overlaps
|
||||
|
||||
return overlapping_docs
|
||||
|
||||
def calibrate_confidence_scores(
|
||||
self,
|
||||
document_store: BaseDocumentStore,
|
||||
|
@ -251,7 +251,7 @@ def test_id_hash_keys_from_pipeline_params():
|
||||
# test_input is a tuple consisting of the parameters for split_length, split_overlap and split_respect_sentence_boundary
|
||||
# and the expected index in the output list of Documents where the page number changes from 1 to 2
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.parametrize("test_input", [(10, 0, True, 5), (10, 0, False, 4), (10, 5, True, 6), (10, 5, False, 7)])
|
||||
@pytest.mark.parametrize("test_input", [(10, 0, True, 5), (10, 0, False, 4), (10, 5, True, 5), (10, 5, False, 7)])
|
||||
def test_page_number_extraction(test_input):
|
||||
split_length, overlap, resp_sent_boundary, exp_doc_index = test_input
|
||||
preprocessor = PreProcessor(
|
||||
@ -540,3 +540,62 @@ def test_preprocessor_very_long_document(caplog):
|
||||
assert results == documents
|
||||
for i in range(5):
|
||||
assert f"is 6{i} characters long after preprocessing, where the maximum length should be 10." in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_split_respect_sentence_boundary_exceeding_split_len_not_repeated():
|
||||
preproc = PreProcessor(split_length=13, split_overlap=3, split_by="word", split_respect_sentence_boundary=True)
|
||||
document = Document(
|
||||
content=(
|
||||
"This is a test sentence with many many words that exceeds the split length and should not be repeated. "
|
||||
"This is another test sentence. (This is a third test sentence.) "
|
||||
"This is the last test sentence."
|
||||
)
|
||||
)
|
||||
documents = preproc.process(document)
|
||||
assert len(documents) == 3
|
||||
assert (
|
||||
documents[0].content
|
||||
== "This is a test sentence with many many words that exceeds the split length and should not be repeated. "
|
||||
)
|
||||
assert "This is a test sentence with many many words" not in documents[1].content
|
||||
assert "This is a test sentence with many many words" not in documents[2].content
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_split_overlap_information():
|
||||
preproc = PreProcessor(split_length=13, split_overlap=3, split_by="word", split_respect_sentence_boundary=True)
|
||||
document = Document(
|
||||
content=(
|
||||
"This is a test sentence with many many words that exceeds the split length and should not be repeated. "
|
||||
"This is another test sentence. (This is a third test sentence.) This is the fourth sentence. "
|
||||
"This is the last test sentence."
|
||||
)
|
||||
)
|
||||
documents = preproc.process(document)
|
||||
assert len(documents) == 4
|
||||
# The first Document should not overlap with any other Document as it exceeds the split length, the other Documents
|
||||
# should overlap with the previous Document (if applicable) and the next Document (if applicable)
|
||||
assert len(documents[0].meta["_split_overlap"]) == 0
|
||||
assert len(documents[1].meta["_split_overlap"]) == 1
|
||||
assert len(documents[2].meta["_split_overlap"]) == 2
|
||||
assert len(documents[3].meta["_split_overlap"]) == 1
|
||||
|
||||
assert documents[1].meta["_split_overlap"][0]["doc_id"] == documents[2].id
|
||||
assert documents[2].meta["_split_overlap"][0]["doc_id"] == documents[1].id
|
||||
assert documents[2].meta["_split_overlap"][1]["doc_id"] == documents[3].id
|
||||
assert documents[3].meta["_split_overlap"][0]["doc_id"] == documents[2].id
|
||||
|
||||
doc1_overlap_doc2 = documents[1].meta["_split_overlap"][0]["range"]
|
||||
doc2_overlap_doc1 = documents[2].meta["_split_overlap"][0]["range"]
|
||||
assert (
|
||||
documents[1].content[doc1_overlap_doc2[0] : doc1_overlap_doc2[1]]
|
||||
== documents[2].content[doc2_overlap_doc1[0] : doc2_overlap_doc1[1]]
|
||||
)
|
||||
|
||||
doc2_overlap_doc3 = documents[2].meta["_split_overlap"][1]["range"]
|
||||
doc3_overlap_doc2 = documents[3].meta["_split_overlap"][0]["range"]
|
||||
assert (
|
||||
documents[2].content[doc2_overlap_doc3[0] : doc2_overlap_doc3[1]]
|
||||
== documents[3].content[doc3_overlap_doc2[0] : doc3_overlap_doc2[1]]
|
||||
)
|
||||
|
@ -144,6 +144,27 @@ def test_no_answer_output(no_answer_reader, docs):
|
||||
assert len(no_answer_prediction["answers"]) == 5
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_deduplication_for_overlapping_documents(reader):
|
||||
docs = [
|
||||
Document(
|
||||
content="My name is Carla. I live in Berlin.",
|
||||
id="doc1",
|
||||
meta={"_split_id": 0, "_split_overlap": [{"doc_id": "doc2", "range": (18, 35)}]},
|
||||
),
|
||||
Document(
|
||||
content="I live in Berlin. My friends call me Carla.",
|
||||
id="doc2",
|
||||
meta={"_split_id": 1, "_split_overlap": [{"doc_id": "doc1", "range": (0, 17)}]},
|
||||
),
|
||||
]
|
||||
prediction = reader.predict(query="Where does Carla live?", documents=docs, top_k=5)
|
||||
|
||||
# Check that there are no duplicate answers
|
||||
assert len(set(ans.answer for ans in prediction["answers"])) == len(prediction["answers"])
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_model_download_options():
|
||||
# download disabled and model is not cached locally
|
||||
|
Loading…
x
Reference in New Issue
Block a user