diff --git a/haystack/modeling/model/predictions.py b/haystack/modeling/model/predictions.py
index baa5dfc72..db74c4e59 100644
--- a/haystack/modeling/model/predictions.py
+++ b/haystack/modeling/model/predictions.py
@@ -207,6 +207,7 @@ class QACandidate:
         # if the original text contained multiple consecutive whitespaces
         cleaned_final_text = final_text.strip()
         if not cleaned_final_text:
+            self.answer_type = "no_answer"
             return "", 0, 0
 
         # Adjust the offsets in case of whitespace at the beginning of the answer
diff --git a/haystack/nodes/preprocessor/preprocessor.py b/haystack/nodes/preprocessor/preprocessor.py
index e384f4c3b..935f60372 100644
--- a/haystack/nodes/preprocessor/preprocessor.py
+++ b/haystack/nodes/preprocessor/preprocessor.py
@@ -95,7 +95,7 @@ class PreProcessor(BasePreProcessor):
                             In this case the id will be generated by using the content and the defined metadata.
         :param progress_bar: Whether to show a progress bar.
         :param add_page_number: Add the number of the page a paragraph occurs in to the Document's meta
-                                field `"page"`. Page boundaries are determined by `"\f"' character which is added
+                                field `"page"`. Page boundaries are determined by the `"\f"` character which is added
                                 in between pages by `PDFToTextConverter`, `TikaConverter`, `ParsrConverter` and
                                 `AzureConverter`.
         :param max_chars_check: the maximum length a document is expected to have. Each document that is longer
                                 than max_chars_check in characters after pre-processing will raise a warning.
@@ -371,6 +371,7 @@ class PreProcessor(BasePreProcessor):
                 splits_start_idxs=splits_start_idxs,
                 headlines=headlines,
                 meta=document.meta or {},
+                split_overlap=split_overlap,
                 id_hash_keys=id_hash_keys,
             )
 
@@ -482,21 +483,9 @@ class PreProcessor(BasePreProcessor):
                     splits_start_idxs.append(cur_start_idx)
 
                 if split_overlap:
-                    overlap = []
-                    processed_sents = []
-                    word_count_overlap = 0
-                    current_slice_copy = deepcopy(current_slice)
-                    for idx, s in reversed(list(enumerate(current_slice))):
-                        sen_len = len(s.split())
-                        if word_count_overlap < split_overlap:
-                            overlap.append(s)
-                            word_count_overlap += sen_len
-                            current_slice_copy.pop(idx)
-                        else:
-                            processed_sents = current_slice_copy
-                            break
-                    current_slice = list(reversed(overlap))
-                    word_count_slice = word_count_overlap
+                    processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice(
+                        current_slice, split_length, split_overlap
+                    )
                 else:
                     processed_sents = current_slice
                     current_slice = []
@@ -530,6 +519,35 @@ class PreProcessor(BasePreProcessor):
 
         return text_splits, splits_pages, splits_start_idxs
 
+    @staticmethod
+    def _get_overlap_from_slice(
+        current_slice: List[str], split_length: int, split_overlap: int
+    ) -> Tuple[List[str], List[str], int]:
+        """
+        Returns a tuple with the following elements:
+        - processed_sents: List of sentences that do not overlap with the next slice (= completely processed sentences)
+        - next_slice: List of sentences that overlap with the next slice
+        - word_count_slice: Number of words in the next slice
+        """
+
+        overlap = []
+        word_count_overlap = 0
+        current_slice_copy = deepcopy(current_slice)
+        # The next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
+        for idx, s in reversed(list(enumerate(current_slice))[1:]):
+            sen_len = len(s.split())
+            if word_count_overlap < split_overlap and sen_len < split_length:
+                overlap.append(s)
+                word_count_overlap += sen_len
+                current_slice_copy.pop(idx)
+            else:
+                break
+        processed_sents = current_slice_copy
+        next_slice = list(reversed(overlap))
+        word_count_slice = word_count_overlap
+
+        return processed_sents, next_slice, word_count_slice
+
     def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
         if split_by == "passage":
             elements = text.split("\n\n")
@@ -580,12 +598,13 @@ class PreProcessor(BasePreProcessor):
         splits_start_idxs: List[int],
         headlines: List[Dict],
         meta: Dict,
+        split_overlap: int,
         id_hash_keys=Optional[List[str]],
     ) -> List[Document]:
         """
        Creates Document objects from text splits enriching them with page number and headline information if given.
         """
-        documents = []
+        documents: List[Document] = []
         earliest_rel_hl = 0
         for i, txt in enumerate(text_splits):
@@ -600,11 +619,35 @@ class PreProcessor(BasePreProcessor):
                 headlines=headlines, split_txt=txt, split_start_idx=split_start_idx, earliest_rel_hl=earliest_rel_hl
             )
             doc.meta["headlines"] = relevant_headlines
+            if split_overlap > 0:
+                doc.meta["_split_overlap"] = []
+                if i != 0:
+                    doc_start_idx = splits_start_idxs[i]
+                    previous_doc = documents[i - 1]
+                    previous_doc_start_idx = splits_start_idxs[i - 1]
+                    self._add_split_overlap_information(doc, doc_start_idx, previous_doc, previous_doc_start_idx)
 
             documents.append(doc)
 
         return documents
 
+    @staticmethod
+    def _add_split_overlap_information(
+        current_doc: Document, current_doc_start_idx: int, previous_doc: Document, previous_doc_start_idx: int
+    ):
+        """
+        Adds split overlap information to the meta of the current and the previous Document.
+        """
+        overlapping_range = (current_doc_start_idx - previous_doc_start_idx, len(previous_doc.content) - 1)
+        if overlapping_range[0] < overlapping_range[1]:
+            overlapping_str = previous_doc.content[overlapping_range[0] : overlapping_range[1]]
+            if current_doc.content.startswith(overlapping_str):
+                # Add split overlap information to the previous Document regarding this Document
+                previous_doc.meta["_split_overlap"].append({"doc_id": current_doc.id, "range": overlapping_range})
+                # Add split overlap information to this Document regarding the previous Document
+                overlapping_range = (0, overlapping_range[1] - overlapping_range[0])
+                current_doc.meta["_split_overlap"].append({"doc_id": previous_doc.id, "range": overlapping_range})
+
     @staticmethod
     def _extract_relevant_headlines_for_split(
         headlines: List[Dict], split_txt: str, split_start_idx: int, earliest_rel_hl: int
@@ -752,7 +795,8 @@ class PreProcessor(BasePreProcessor):
             period_context_fmt
             % {
                 "NonWord": sentence_tokenizer._lang_vars._re_non_word_chars,
-                "SentEndChars": sentence_tokenizer._lang_vars._re_sent_end_chars,
+                # SentEndChars might be followed by closing brackets, so we match them here.
+                "SentEndChars": sentence_tokenizer._lang_vars._re_sent_end_chars + r"[\)\]}]*",
             },
             re.UNICODE | re.VERBOSE,
         )
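For reference, a minimal sketch of how the new `_split_overlap` meta field surfaces to callers. It uses only the `PreProcessor` options exercised by the tests further down; the concrete ids and ranges depend on the input text and the sentence tokenizer, so the printed values are illustrative.

```python
from haystack import Document
from haystack.nodes import PreProcessor

preproc = PreProcessor(split_length=13, split_overlap=3, split_by="word", split_respect_sentence_boundary=True)
doc = Document(
    content=(
        "This is another test sentence. (This is a third test sentence.) "
        "This is the fourth sentence. This is the last test sentence."
    )
)
for split in preproc.process(doc):
    # Each entry points at a neighboring split and gives the character range
    # shared with it, e.g. {"doc_id": "<id of the neighbor>", "range": (start, end)}
    print(split.id, split.meta["_split_overlap"])
```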
diff --git a/haystack/nodes/reader/farm.py b/haystack/nodes/reader/farm.py
index 6a0a825cc..f97e3a236 100644
--- a/haystack/nodes/reader/farm.py
+++ b/haystack/nodes/reader/farm.py
@@ -913,6 +913,8 @@ class FARMReader(BaseReader):
         predictions = self.inferencer.inference_from_objects(
             objects=inputs, return_json=False, multiprocessing_chunksize=1
         )
+        # Deduplicate identical answers resulting from Document split overlap
+        predictions = self._deduplicate_predictions(predictions, documents)
         # assemble answers from all the different documents & format them.
         answers, max_no_ans_gap = self._extract_answers_of_predictions(predictions, top_k)
         # TODO: potentially simplify return here to List[Answer] and handle no_ans_gap differently
@@ -1262,6 +1264,82 @@ class FARMReader(BaseReader):
 
         return inputs, number_of_docs, single_doc_list
 
+    def _deduplicate_predictions(self, predictions: List[QAPred], documents: List[Document]) -> List[QAPred]:
+        overlapping_docs = self._identify_overlapping_docs(documents)
+        if not overlapping_docs:
+            return predictions
+
+        preds_per_doc = {pred.id: pred for pred in predictions}
+        for pred in predictions:
+            # Check if the current Document overlaps with the Documents of other preds and continue if not
+            if pred.id not in overlapping_docs:
+                continue
+
+            relevant_overlaps = overlapping_docs[pred.id]
+            for ans_idx in reversed(range(len(pred.prediction))):
+                ans = pred.prediction[ans_idx]
+                if ans.answer_type != "span":
+                    continue
+
+                for overlap in relevant_overlaps:
+                    # Check if the answer offsets are within the overlap
+                    if not self._qa_cand_in_overlap(ans, overlap):
+                        continue
+
+                    # Check if predictions from the overlapping Document are within the overlap
+                    overlapping_doc_pred = preds_per_doc[overlap["doc_id"]]
+                    cur_doc_overlap = [ol for ol in overlapping_docs[overlap["doc_id"]] if ol["doc_id"] == pred.id][0]
+                    for pot_dupl_ans_idx in reversed(range(len(overlapping_doc_pred.prediction))):
+                        pot_dupl_ans = overlapping_doc_pred.prediction[pot_dupl_ans_idx]
+                        if pot_dupl_ans.answer_type != "span":
+                            continue
+                        if not self._qa_cand_in_overlap(pot_dupl_ans, cur_doc_overlap):
+                            continue
+
+                        # Check if ans and pot_dupl_ans are duplicates
+                        if self._is_duplicate_answer(ans, overlap, pot_dupl_ans, cur_doc_overlap):
+                            # Discard the duplicate with the lower score
+                            if ans.confidence < pot_dupl_ans.confidence:
+                                pred.prediction.pop(ans_idx)
+                            else:
+                                overlapping_doc_pred.prediction.pop(pot_dupl_ans_idx)
+
+        return predictions
+
+    @staticmethod
+    def _is_duplicate_answer(
+        ans: QACandidate, ans_overlap: Dict, pot_dupl_ans: QACandidate, pot_dupl_ans_overlap: Dict
+    ) -> bool:
+        answer_start_in_overlap = ans.offset_answer_start - ans_overlap["range"][0]
+        answer_end_in_overlap = ans.offset_answer_end - ans_overlap["range"][0]
+
+        pot_dupl_ans_start_in_overlap = pot_dupl_ans.offset_answer_start - pot_dupl_ans_overlap["range"][0]
+        pot_dupl_ans_end_in_overlap = pot_dupl_ans.offset_answer_end - pot_dupl_ans_overlap["range"][0]
+
+        return (
+            answer_start_in_overlap == pot_dupl_ans_start_in_overlap
+            and answer_end_in_overlap == pot_dupl_ans_end_in_overlap
+        )
+
+    @staticmethod
+    def _qa_cand_in_overlap(cand: QACandidate, overlap: Dict) -> bool:
+        if cand.offset_answer_start < overlap["range"][0] or cand.offset_answer_end > overlap["range"][1]:
+            return False
+        return True
+
+    @staticmethod
+    def _identify_overlapping_docs(documents: List[Document]) -> Dict[str, List]:
+        docs_by_ids = {doc.id: doc for doc in documents}
+        overlapping_docs = {}
+        for doc in documents:
+            if "_split_overlap" not in doc.meta:
+                continue
+            current_overlaps = [overlap for overlap in doc.meta["_split_overlap"] if overlap["doc_id"] in docs_by_ids]
+            if current_overlaps:
+                overlapping_docs[doc.id] = current_overlaps
+
+        return overlapping_docs
+
     def calibrate_confidence_scores(
         self,
         document_store: BaseDocumentStore,
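The heart of the deduplication is the offset arithmetic in `_is_duplicate_answer`: two span answers count as duplicates exactly when their character offsets, re-based to the start of the overlap window in their respective Documents, coincide. A self-contained illustration using the two Documents from the reader test below (the offsets are computed by hand and only mirror the method's arithmetic; they are not produced by the reader itself):

```python
# doc1 = "My name is Carla. I live in Berlin."
# doc2 = "I live in Berlin. My friends call me Carla."
doc1_overlap = {"doc_id": "doc2", "range": (18, 35)}  # "I live in Berlin." at the end of doc1
doc2_overlap = {"doc_id": "doc1", "range": (0, 17)}   # the same text at the start of doc2

ans_in_doc1 = (28, 34)  # character span of "Berlin" within doc1
ans_in_doc2 = (10, 16)  # character span of "Berlin" within doc2

# Re-base both spans to the start of their overlap window
rebased_1 = (ans_in_doc1[0] - doc1_overlap["range"][0], ans_in_doc1[1] - doc1_overlap["range"][0])
rebased_2 = (ans_in_doc2[0] - doc2_overlap["range"][0], ans_in_doc2[1] - doc2_overlap["range"][0])

# Equal re-based spans -> the same answer at the same position in the shared text,
# so the candidate with the lower confidence is discarded
assert rebased_1 == rebased_2 == (10, 16)
```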
diff --git a/test/nodes/test_preprocessor.py b/test/nodes/test_preprocessor.py
index 5600f2439..5ccc46ac8 100644
--- a/test/nodes/test_preprocessor.py
+++ b/test/nodes/test_preprocessor.py
@@ -251,7 +251,7 @@ def test_id_hash_keys_from_pipeline_params():
 # test_input is a tuple consisting of the parameters for split_length, split_overlap and split_respect_sentence_boundary
 # and the expected index in the output list of Documents where the page number changes from 1 to 2
 @pytest.mark.unit
-@pytest.mark.parametrize("test_input", [(10, 0, True, 5), (10, 0, False, 4), (10, 5, True, 6), (10, 5, False, 7)])
+@pytest.mark.parametrize("test_input", [(10, 0, True, 5), (10, 0, False, 4), (10, 5, True, 5), (10, 5, False, 7)])
 def test_page_number_extraction(test_input):
     split_length, overlap, resp_sent_boundary, exp_doc_index = test_input
     preprocessor = PreProcessor(
@@ -540,3 +540,62 @@ def test_preprocessor_very_long_document(caplog):
     assert results == documents
     for i in range(5):
         assert f"is 6{i} characters long after preprocessing, where the maximum length should be 10." in caplog.text
+
+
+@pytest.mark.unit
+def test_split_respect_sentence_boundary_exceeding_split_len_not_repeated():
+    preproc = PreProcessor(split_length=13, split_overlap=3, split_by="word", split_respect_sentence_boundary=True)
+    document = Document(
+        content=(
+            "This is a test sentence with many many words that exceeds the split length and should not be repeated. "
+            "This is another test sentence. (This is a third test sentence.) "
+            "This is the last test sentence."
+        )
+    )
+    documents = preproc.process(document)
+    assert len(documents) == 3
+    assert (
+        documents[0].content
+        == "This is a test sentence with many many words that exceeds the split length and should not be repeated. "
+    )
+    assert "This is a test sentence with many many words" not in documents[1].content
+    assert "This is a test sentence with many many words" not in documents[2].content
+
+
+@pytest.mark.unit
+def test_split_overlap_information():
+    preproc = PreProcessor(split_length=13, split_overlap=3, split_by="word", split_respect_sentence_boundary=True)
+    document = Document(
+        content=(
+            "This is a test sentence with many many words that exceeds the split length and should not be repeated. "
+            "This is another test sentence. (This is a third test sentence.) This is the fourth sentence. "
+            "This is the last test sentence."
+        )
+    )
+    documents = preproc.process(document)
+    assert len(documents) == 4
+    # The first Document should not overlap with any other Document as it exceeds the split length; the other Documents
+    # should overlap with the previous Document (if applicable) and the next Document (if applicable)
+    assert len(documents[0].meta["_split_overlap"]) == 0
+    assert len(documents[1].meta["_split_overlap"]) == 1
+    assert len(documents[2].meta["_split_overlap"]) == 2
+    assert len(documents[3].meta["_split_overlap"]) == 1
+
+    assert documents[1].meta["_split_overlap"][0]["doc_id"] == documents[2].id
+    assert documents[2].meta["_split_overlap"][0]["doc_id"] == documents[1].id
+    assert documents[2].meta["_split_overlap"][1]["doc_id"] == documents[3].id
+    assert documents[3].meta["_split_overlap"][0]["doc_id"] == documents[2].id
+
+    doc1_overlap_doc2 = documents[1].meta["_split_overlap"][0]["range"]
+    doc2_overlap_doc1 = documents[2].meta["_split_overlap"][0]["range"]
+    assert (
+        documents[1].content[doc1_overlap_doc2[0] : doc1_overlap_doc2[1]]
+        == documents[2].content[doc2_overlap_doc1[0] : doc2_overlap_doc1[1]]
+    )
+
+    doc2_overlap_doc3 = documents[2].meta["_split_overlap"][1]["range"]
+    doc3_overlap_doc2 = documents[3].meta["_split_overlap"][0]["range"]
+    assert (
+        documents[2].content[doc2_overlap_doc3[0] : doc2_overlap_doc3[1]]
+        == documents[3].content[doc3_overlap_doc2[0] : doc3_overlap_doc2[1]]
+    )
diff --git a/test/nodes/test_reader.py b/test/nodes/test_reader.py
index c21076978..cb68d6d0c 100644
--- a/test/nodes/test_reader.py
+++ b/test/nodes/test_reader.py
@@ -144,6 +144,27 @@ def test_no_answer_output(no_answer_reader, docs):
     assert len(no_answer_prediction["answers"]) == 5
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("reader", ["farm"], indirect=True)
+def test_deduplication_for_overlapping_documents(reader):
+    docs = [
+        Document(
+            content="My name is Carla. I live in Berlin.",
+            id="doc1",
+            meta={"_split_id": 0, "_split_overlap": [{"doc_id": "doc2", "range": (18, 35)}]},
+        ),
+        Document(
+            content="I live in Berlin. My friends call me Carla.",
+            id="doc2",
+            meta={"_split_id": 1, "_split_overlap": [{"doc_id": "doc1", "range": (0, 17)}]},
+        ),
+    ]
+    prediction = reader.predict(query="Where does Carla live?", documents=docs, top_k=5)
+
+    # Check that there are no duplicate answers
+    assert len(set(ans.answer for ans in prediction["answers"])) == len(prediction["answers"])
+
+
 @pytest.mark.integration
 def test_model_download_options():
     # download disabled and model is not cached locally
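Putting both halves together, an end-to-end sketch of the intended usage. The model name is just a common extractive-QA checkpoint and the input text is a stand-in; any reader model and any long document work the same way.

```python
from haystack import Document
from haystack.nodes import FARMReader, PreProcessor

preproc = PreProcessor(split_by="word", split_length=50, split_overlap=10)
docs = preproc.process(Document(content="<some long article about Carla and Berlin>"))

reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")
prediction = reader.predict(query="Where does Carla live?", documents=docs, top_k=5)

# Overlapping splits carry "_split_overlap" meta, so an answer that sits inside the
# shared region is returned only once
answers = [ans.answer for ans in prediction["answers"]]
assert len(set(answers)) == len(answers)
```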