mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-28 09:35:42 +00:00
feat: Deduplicate duplicate Answers resulting from overlapping Documents in FARMReader
(#4470)
* Deduplicate answers resulting from document split overlap * Add tests * Fix Pylint * Adapt existing test * Incorporate PR feedback
This commit is contained in:
parent
de825ded1c
commit
ed1837c0c9
@ -207,6 +207,7 @@ class QACandidate:
|
|||||||
# if the original text contained multiple consecutive whitespaces
|
# if the original text contained multiple consecutive whitespaces
|
||||||
cleaned_final_text = final_text.strip()
|
cleaned_final_text = final_text.strip()
|
||||||
if not cleaned_final_text:
|
if not cleaned_final_text:
|
||||||
|
self.answer_type = "no_answer"
|
||||||
return "", 0, 0
|
return "", 0, 0
|
||||||
|
|
||||||
# Adjust the offsets in case of whitespace at the beginning of the answer
|
# Adjust the offsets in case of whitespace at the beginning of the answer
|
||||||
|
@ -95,7 +95,7 @@ class PreProcessor(BasePreProcessor):
|
|||||||
In this case the id will be generated by using the content and the defined metadata.
|
In this case the id will be generated by using the content and the defined metadata.
|
||||||
:param progress_bar: Whether to show a progress bar.
|
:param progress_bar: Whether to show a progress bar.
|
||||||
:param add_page_number: Add the number of the page a paragraph occurs in to the Document's meta
|
:param add_page_number: Add the number of the page a paragraph occurs in to the Document's meta
|
||||||
field `"page"`. Page boundaries are determined by `"\f"' character which is added
|
field `"page"`. Page boundaries are determined by `"\f"` character which is added
|
||||||
in between pages by `PDFToTextConverter`, `TikaConverter`, `ParsrConverter` and
|
in between pages by `PDFToTextConverter`, `TikaConverter`, `ParsrConverter` and
|
||||||
`AzureConverter`.
|
`AzureConverter`.
|
||||||
:param max_chars_check: the maximum length a document is expected to have. Each document that is longer than max_chars_check in characters after pre-processing will raise a warning.
|
:param max_chars_check: the maximum length a document is expected to have. Each document that is longer than max_chars_check in characters after pre-processing will raise a warning.
|
||||||
@ -371,6 +371,7 @@ class PreProcessor(BasePreProcessor):
|
|||||||
splits_start_idxs=splits_start_idxs,
|
splits_start_idxs=splits_start_idxs,
|
||||||
headlines=headlines,
|
headlines=headlines,
|
||||||
meta=document.meta or {},
|
meta=document.meta or {},
|
||||||
|
split_overlap=split_overlap,
|
||||||
id_hash_keys=id_hash_keys,
|
id_hash_keys=id_hash_keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -482,21 +483,9 @@ class PreProcessor(BasePreProcessor):
|
|||||||
splits_start_idxs.append(cur_start_idx)
|
splits_start_idxs.append(cur_start_idx)
|
||||||
|
|
||||||
if split_overlap:
|
if split_overlap:
|
||||||
overlap = []
|
processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice(
|
||||||
processed_sents = []
|
current_slice, split_length, split_overlap
|
||||||
word_count_overlap = 0
|
)
|
||||||
current_slice_copy = deepcopy(current_slice)
|
|
||||||
for idx, s in reversed(list(enumerate(current_slice))):
|
|
||||||
sen_len = len(s.split())
|
|
||||||
if word_count_overlap < split_overlap:
|
|
||||||
overlap.append(s)
|
|
||||||
word_count_overlap += sen_len
|
|
||||||
current_slice_copy.pop(idx)
|
|
||||||
else:
|
|
||||||
processed_sents = current_slice_copy
|
|
||||||
break
|
|
||||||
current_slice = list(reversed(overlap))
|
|
||||||
word_count_slice = word_count_overlap
|
|
||||||
else:
|
else:
|
||||||
processed_sents = current_slice
|
processed_sents = current_slice
|
||||||
current_slice = []
|
current_slice = []
|
||||||
@ -530,6 +519,35 @@ class PreProcessor(BasePreProcessor):
|
|||||||
|
|
||||||
return text_splits, splits_pages, splits_start_idxs
|
return text_splits, splits_pages, splits_start_idxs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_overlap_from_slice(
|
||||||
|
current_slice: List[str], split_length: int, split_overlap: int
|
||||||
|
) -> Tuple[List[str], List[str], int]:
|
||||||
|
"""
|
||||||
|
Returns a tuple with the following elements:
|
||||||
|
- processed_sents: List of sentences that are not overlapping the with next slice (= completely processed sentences)
|
||||||
|
- next_slice: List of sentences that are overlapping with the next slice
|
||||||
|
- word_count_slice: Number of words in the next slice
|
||||||
|
"""
|
||||||
|
|
||||||
|
overlap = []
|
||||||
|
word_count_overlap = 0
|
||||||
|
current_slice_copy = deepcopy(current_slice)
|
||||||
|
# Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
|
||||||
|
for idx, s in reversed(list(enumerate(current_slice))[1:]):
|
||||||
|
sen_len = len(s.split())
|
||||||
|
if word_count_overlap < split_overlap and sen_len < split_length:
|
||||||
|
overlap.append(s)
|
||||||
|
word_count_overlap += sen_len
|
||||||
|
current_slice_copy.pop(idx)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
processed_sents = current_slice_copy
|
||||||
|
next_slice = list(reversed(overlap))
|
||||||
|
word_count_slice = word_count_overlap
|
||||||
|
|
||||||
|
return processed_sents, next_slice, word_count_slice
|
||||||
|
|
||||||
def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
|
def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
|
||||||
if split_by == "passage":
|
if split_by == "passage":
|
||||||
elements = text.split("\n\n")
|
elements = text.split("\n\n")
|
||||||
@ -580,12 +598,13 @@ class PreProcessor(BasePreProcessor):
|
|||||||
splits_start_idxs: List[int],
|
splits_start_idxs: List[int],
|
||||||
headlines: List[Dict],
|
headlines: List[Dict],
|
||||||
meta: Dict,
|
meta: Dict,
|
||||||
|
split_overlap: int,
|
||||||
id_hash_keys=Optional[List[str]],
|
id_hash_keys=Optional[List[str]],
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
Creates Document objects from text splits enriching them with page number and headline information if given.
|
Creates Document objects from text splits enriching them with page number and headline information if given.
|
||||||
"""
|
"""
|
||||||
documents = []
|
documents: List[Document] = []
|
||||||
|
|
||||||
earliest_rel_hl = 0
|
earliest_rel_hl = 0
|
||||||
for i, txt in enumerate(text_splits):
|
for i, txt in enumerate(text_splits):
|
||||||
@ -600,11 +619,35 @@ class PreProcessor(BasePreProcessor):
|
|||||||
headlines=headlines, split_txt=txt, split_start_idx=split_start_idx, earliest_rel_hl=earliest_rel_hl
|
headlines=headlines, split_txt=txt, split_start_idx=split_start_idx, earliest_rel_hl=earliest_rel_hl
|
||||||
)
|
)
|
||||||
doc.meta["headlines"] = relevant_headlines
|
doc.meta["headlines"] = relevant_headlines
|
||||||
|
if split_overlap > 0:
|
||||||
|
doc.meta["_split_overlap"] = []
|
||||||
|
if i != 0:
|
||||||
|
doc_start_idx = splits_start_idxs[i]
|
||||||
|
previous_doc = documents[i - 1]
|
||||||
|
previous_doc_start_idx = splits_start_idxs[i - 1]
|
||||||
|
self._add_split_overlap_information(doc, doc_start_idx, previous_doc, previous_doc_start_idx)
|
||||||
|
|
||||||
documents.append(doc)
|
documents.append(doc)
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _add_split_overlap_information(
|
||||||
|
current_doc: Document, current_doc_start_idx: int, previous_doc: Document, previos_doc_start_idx: int
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Adds split overlap information to the current and previous Document's meta.
|
||||||
|
"""
|
||||||
|
overlapping_range = (current_doc_start_idx - previos_doc_start_idx, len(previous_doc.content) - 1)
|
||||||
|
if overlapping_range[0] < overlapping_range[1]:
|
||||||
|
overlapping_str = previous_doc.content[overlapping_range[0] : overlapping_range[1]]
|
||||||
|
if current_doc.content.startswith(overlapping_str):
|
||||||
|
# Add split overlap information to previous Document regarding this Document
|
||||||
|
previous_doc.meta["_split_overlap"].append({"doc_id": current_doc.id, "range": overlapping_range})
|
||||||
|
# Add split overlap information to this Document regarding the previous Document
|
||||||
|
overlapping_range = (0, overlapping_range[1] - overlapping_range[0])
|
||||||
|
current_doc.meta["_split_overlap"].append({"doc_id": previous_doc.id, "range": overlapping_range})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_relevant_headlines_for_split(
|
def _extract_relevant_headlines_for_split(
|
||||||
headlines: List[Dict], split_txt: str, split_start_idx: int, earliest_rel_hl: int
|
headlines: List[Dict], split_txt: str, split_start_idx: int, earliest_rel_hl: int
|
||||||
@ -752,7 +795,8 @@ class PreProcessor(BasePreProcessor):
|
|||||||
period_context_fmt
|
period_context_fmt
|
||||||
% {
|
% {
|
||||||
"NonWord": sentence_tokenizer._lang_vars._re_non_word_chars,
|
"NonWord": sentence_tokenizer._lang_vars._re_non_word_chars,
|
||||||
"SentEndChars": sentence_tokenizer._lang_vars._re_sent_end_chars,
|
# SentEndChars might be followed by closing brackets, so we match them here.
|
||||||
|
"SentEndChars": sentence_tokenizer._lang_vars._re_sent_end_chars + r"[\)\]}]*",
|
||||||
},
|
},
|
||||||
re.UNICODE | re.VERBOSE,
|
re.UNICODE | re.VERBOSE,
|
||||||
)
|
)
|
||||||
|
@ -913,6 +913,8 @@ class FARMReader(BaseReader):
|
|||||||
predictions = self.inferencer.inference_from_objects(
|
predictions = self.inferencer.inference_from_objects(
|
||||||
objects=inputs, return_json=False, multiprocessing_chunksize=1
|
objects=inputs, return_json=False, multiprocessing_chunksize=1
|
||||||
)
|
)
|
||||||
|
# Deduplicate same answers resulting from Document split overlap
|
||||||
|
predictions = self._deduplicate_predictions(predictions, documents)
|
||||||
# assemble answers from all the different documents & format them.
|
# assemble answers from all the different documents & format them.
|
||||||
answers, max_no_ans_gap = self._extract_answers_of_predictions(predictions, top_k)
|
answers, max_no_ans_gap = self._extract_answers_of_predictions(predictions, top_k)
|
||||||
# TODO: potentially simplify return here to List[Answer] and handle no_ans_gap differently
|
# TODO: potentially simplify return here to List[Answer] and handle no_ans_gap differently
|
||||||
@ -1262,6 +1264,82 @@ class FARMReader(BaseReader):
|
|||||||
|
|
||||||
return inputs, number_of_docs, single_doc_list
|
return inputs, number_of_docs, single_doc_list
|
||||||
|
|
||||||
|
def _deduplicate_predictions(self, predictions: List[QAPred], documents: List[Document]) -> List[QAPred]:
|
||||||
|
overlapping_docs = self._identify_overlapping_docs(documents)
|
||||||
|
if not overlapping_docs:
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
preds_per_doc = {pred.id: pred for pred in predictions}
|
||||||
|
for pred in predictions:
|
||||||
|
# Check if current Document overlaps with Documents of other preds and continue if not
|
||||||
|
if pred.id not in overlapping_docs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
relevant_overlaps = overlapping_docs[pred.id]
|
||||||
|
for ans_idx in reversed(range(len(pred.prediction))):
|
||||||
|
ans = pred.prediction[ans_idx]
|
||||||
|
if ans.answer_type != "span":
|
||||||
|
continue
|
||||||
|
|
||||||
|
for overlap in relevant_overlaps:
|
||||||
|
# Check if answer offsets are within the overlap
|
||||||
|
if not self._qa_cand_in_overlap(ans, overlap):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if predictions from overlapping Document are within the overlap
|
||||||
|
overlapping_doc_pred = preds_per_doc[overlap["doc_id"]]
|
||||||
|
cur_doc_overlap = [ol for ol in overlapping_docs[overlap["doc_id"]] if ol["doc_id"] == pred.id][0]
|
||||||
|
for pot_dupl_ans_idx in reversed(range(len(overlapping_doc_pred.prediction))):
|
||||||
|
pot_dupl_ans = overlapping_doc_pred.prediction[pot_dupl_ans_idx]
|
||||||
|
if pot_dupl_ans.answer_type != "span":
|
||||||
|
continue
|
||||||
|
if not self._qa_cand_in_overlap(pot_dupl_ans, cur_doc_overlap):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if ans and pot_dupl_ans are duplicates
|
||||||
|
if self._is_duplicate_answer(ans, overlap, pot_dupl_ans, cur_doc_overlap):
|
||||||
|
# Discard the duplicate with lower score
|
||||||
|
if ans.confidence < pot_dupl_ans.confidence:
|
||||||
|
pred.prediction.pop(ans_idx)
|
||||||
|
else:
|
||||||
|
overlapping_doc_pred.prediction.pop(pot_dupl_ans_idx)
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_duplicate_answer(
|
||||||
|
ans: QACandidate, ans_overlap: Dict, pot_dupl_ans: QACandidate, pot_dupl_ans_overlap: Dict
|
||||||
|
) -> bool:
|
||||||
|
answer_start_in_overlap = ans.offset_answer_start - ans_overlap["range"][0]
|
||||||
|
answer_end_in_overlap = ans.offset_answer_end - ans_overlap["range"][0]
|
||||||
|
|
||||||
|
pot_dupl_ans_start_in_overlap = pot_dupl_ans.offset_answer_start - pot_dupl_ans_overlap["range"][0]
|
||||||
|
pot_dupl_ans_end_in_overlap = pot_dupl_ans.offset_answer_end - pot_dupl_ans_overlap["range"][0]
|
||||||
|
|
||||||
|
return (
|
||||||
|
answer_start_in_overlap == pot_dupl_ans_start_in_overlap
|
||||||
|
and answer_end_in_overlap == pot_dupl_ans_end_in_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _qa_cand_in_overlap(cand: QACandidate, overlap: Dict) -> bool:
|
||||||
|
if cand.offset_answer_start < overlap["range"][0] or cand.offset_answer_end > overlap["range"][1]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _identify_overlapping_docs(documents: List[Document]) -> Dict[str, List]:
|
||||||
|
docs_by_ids = {doc.id: doc for doc in documents}
|
||||||
|
overlapping_docs = {}
|
||||||
|
for doc in documents:
|
||||||
|
if "_split_overlap" not in doc.meta:
|
||||||
|
continue
|
||||||
|
current_overlaps = [overlap for overlap in doc.meta["_split_overlap"] if overlap["doc_id"] in docs_by_ids]
|
||||||
|
if current_overlaps:
|
||||||
|
overlapping_docs[doc.id] = current_overlaps
|
||||||
|
|
||||||
|
return overlapping_docs
|
||||||
|
|
||||||
def calibrate_confidence_scores(
|
def calibrate_confidence_scores(
|
||||||
self,
|
self,
|
||||||
document_store: BaseDocumentStore,
|
document_store: BaseDocumentStore,
|
||||||
|
@ -251,7 +251,7 @@ def test_id_hash_keys_from_pipeline_params():
|
|||||||
# test_input is a tuple consisting of the parameters for split_length, split_overlap and split_respect_sentence_boundary
|
# test_input is a tuple consisting of the parameters for split_length, split_overlap and split_respect_sentence_boundary
|
||||||
# and the expected index in the output list of Documents where the page number changes from 1 to 2
|
# and the expected index in the output list of Documents where the page number changes from 1 to 2
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
@pytest.mark.parametrize("test_input", [(10, 0, True, 5), (10, 0, False, 4), (10, 5, True, 6), (10, 5, False, 7)])
|
@pytest.mark.parametrize("test_input", [(10, 0, True, 5), (10, 0, False, 4), (10, 5, True, 5), (10, 5, False, 7)])
|
||||||
def test_page_number_extraction(test_input):
|
def test_page_number_extraction(test_input):
|
||||||
split_length, overlap, resp_sent_boundary, exp_doc_index = test_input
|
split_length, overlap, resp_sent_boundary, exp_doc_index = test_input
|
||||||
preprocessor = PreProcessor(
|
preprocessor = PreProcessor(
|
||||||
@ -540,3 +540,62 @@ def test_preprocessor_very_long_document(caplog):
|
|||||||
assert results == documents
|
assert results == documents
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
assert f"is 6{i} characters long after preprocessing, where the maximum length should be 10." in caplog.text
|
assert f"is 6{i} characters long after preprocessing, where the maximum length should be 10." in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_split_respect_sentence_boundary_exceeding_split_len_not_repeated():
|
||||||
|
preproc = PreProcessor(split_length=13, split_overlap=3, split_by="word", split_respect_sentence_boundary=True)
|
||||||
|
document = Document(
|
||||||
|
content=(
|
||||||
|
"This is a test sentence with many many words that exceeds the split length and should not be repeated. "
|
||||||
|
"This is another test sentence. (This is a third test sentence.) "
|
||||||
|
"This is the last test sentence."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
documents = preproc.process(document)
|
||||||
|
assert len(documents) == 3
|
||||||
|
assert (
|
||||||
|
documents[0].content
|
||||||
|
== "This is a test sentence with many many words that exceeds the split length and should not be repeated. "
|
||||||
|
)
|
||||||
|
assert "This is a test sentence with many many words" not in documents[1].content
|
||||||
|
assert "This is a test sentence with many many words" not in documents[2].content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_split_overlap_information():
|
||||||
|
preproc = PreProcessor(split_length=13, split_overlap=3, split_by="word", split_respect_sentence_boundary=True)
|
||||||
|
document = Document(
|
||||||
|
content=(
|
||||||
|
"This is a test sentence with many many words that exceeds the split length and should not be repeated. "
|
||||||
|
"This is another test sentence. (This is a third test sentence.) This is the fourth sentence. "
|
||||||
|
"This is the last test sentence."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
documents = preproc.process(document)
|
||||||
|
assert len(documents) == 4
|
||||||
|
# The first Document should not overlap with any other Document as it exceeds the split length, the other Documents
|
||||||
|
# should overlap with the previous Document (if applicable) and the next Document (if applicable)
|
||||||
|
assert len(documents[0].meta["_split_overlap"]) == 0
|
||||||
|
assert len(documents[1].meta["_split_overlap"]) == 1
|
||||||
|
assert len(documents[2].meta["_split_overlap"]) == 2
|
||||||
|
assert len(documents[3].meta["_split_overlap"]) == 1
|
||||||
|
|
||||||
|
assert documents[1].meta["_split_overlap"][0]["doc_id"] == documents[2].id
|
||||||
|
assert documents[2].meta["_split_overlap"][0]["doc_id"] == documents[1].id
|
||||||
|
assert documents[2].meta["_split_overlap"][1]["doc_id"] == documents[3].id
|
||||||
|
assert documents[3].meta["_split_overlap"][0]["doc_id"] == documents[2].id
|
||||||
|
|
||||||
|
doc1_overlap_doc2 = documents[1].meta["_split_overlap"][0]["range"]
|
||||||
|
doc2_overlap_doc1 = documents[2].meta["_split_overlap"][0]["range"]
|
||||||
|
assert (
|
||||||
|
documents[1].content[doc1_overlap_doc2[0] : doc1_overlap_doc2[1]]
|
||||||
|
== documents[2].content[doc2_overlap_doc1[0] : doc2_overlap_doc1[1]]
|
||||||
|
)
|
||||||
|
|
||||||
|
doc2_overlap_doc3 = documents[2].meta["_split_overlap"][1]["range"]
|
||||||
|
doc3_overlap_doc2 = documents[3].meta["_split_overlap"][0]["range"]
|
||||||
|
assert (
|
||||||
|
documents[2].content[doc2_overlap_doc3[0] : doc2_overlap_doc3[1]]
|
||||||
|
== documents[3].content[doc3_overlap_doc2[0] : doc3_overlap_doc2[1]]
|
||||||
|
)
|
||||||
|
@ -144,6 +144,27 @@ def test_no_answer_output(no_answer_reader, docs):
|
|||||||
assert len(no_answer_prediction["answers"]) == 5
|
assert len(no_answer_prediction["answers"]) == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||||
|
def test_deduplication_for_overlapping_documents(reader):
|
||||||
|
docs = [
|
||||||
|
Document(
|
||||||
|
content="My name is Carla. I live in Berlin.",
|
||||||
|
id="doc1",
|
||||||
|
meta={"_split_id": 0, "_split_overlap": [{"doc_id": "doc2", "range": (18, 35)}]},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
content="I live in Berlin. My friends call me Carla.",
|
||||||
|
id="doc2",
|
||||||
|
meta={"_split_id": 1, "_split_overlap": [{"doc_id": "doc1", "range": (0, 17)}]},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
prediction = reader.predict(query="Where does Carla live?", documents=docs, top_k=5)
|
||||||
|
|
||||||
|
# Check that there are no duplicate answers
|
||||||
|
assert len(set(ans.answer for ans in prediction["answers"])) == len(prediction["answers"])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_model_download_options():
|
def test_model_download_options():
|
||||||
# download disabled and model is not cached locally
|
# download disabled and model is not cached locally
|
||||||
|
Loading…
x
Reference in New Issue
Block a user