From a2c160e7d8e706cd8184eb984db5882350d9d876 Mon Sep 17 00:00:00 2001
From: Julian Risch
Date: Tue, 3 Jan 2023 15:50:14 +0100
Subject: [PATCH] bug: skip empty documents in reader (#3773)

* skip empty documents

* test eval_batch and account for tables
---
 haystack/modeling/infer.py            |  4 ++
 haystack/nodes/reader/base.py         | 59 ++++++++++++++++++++++-----
 haystack/nodes/reader/farm.py         |  1 -
 haystack/nodes/reader/transformers.py | 14 ++++---
 test/nodes/test_reader.py             | 57 +++++++++++++++++++++++++-
 test/pipelines/test_eval.py           | 25 ++++++++++++
 6 files changed, 142 insertions(+), 18 deletions(-)

diff --git a/haystack/modeling/infer.py b/haystack/modeling/infer.py
index 315727197..dcc953ee6 100644
--- a/haystack/modeling/infer.py
+++ b/haystack/modeling/infer.py
@@ -508,6 +508,10 @@ class QAInferencer(Inferencer):
             This parameter has no effect; it will be removed as Inferencer multiprocessing has been deprecated.
         """
+        # Return no predictions if there are no inputs
+        if not objects:
+            return []
+
         dicts = [o.to_dict() for o in objects]
         # TODO investigate this deprecation warning. Timo: I thought we were about to implement Input Objects,
         # then we can and should use inference from (input) objects!
diff --git a/haystack/nodes/reader/base.py b/haystack/nodes/reader/base.py
index b7c524c6e..78ce13b9c 100644
--- a/haystack/nodes/reader/base.py
+++ b/haystack/nodes/reader/base.py
@@ -40,17 +40,25 @@ class BaseReader(BaseComponent):
         # the most significant difference between scores.
         # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
         # No_ans_gap is a list of this most significant difference per document
-        no_ans_gap_array = np.array(no_ans_gaps)
-        max_no_ans_gap = np.max(no_ans_gap_array)
-        # case 1: all passages "no answer" as top score
-        # max_no_ans_gap is negative, so it increases best pos score
-        # case 2: at least one passage predicts an answer (positive no_ans_gap)
-        no_ans_score = best_score_answer - max_no_ans_gap
+
+        # If there is not even one predicted answer, we return a no_answer with score 1.0
+        if best_score_answer == 0 and len(no_ans_gaps) == 0:
+            no_ans_score = 1024.0
+            no_ans_score_scaled = 1.0
+            max_no_ans_gap = 1024.0
+        else:
+            no_ans_gap_array = np.array(no_ans_gaps)
+            max_no_ans_gap = np.max(no_ans_gap_array)
+            # case 1: all passages "no answer" as top score
+            # max_no_ans_gap is negative, so it increases best pos score
+            # case 2: at least one passage predicts an answer (positive no_ans_gap)
+            no_ans_score = best_score_answer - max_no_ans_gap
+
+            no_ans_score_scaled = float(expit(np.asarray(no_ans_score) / 8))
 
         no_ans_prediction = Answer(
             answer="",
             type="extractive",
-            score=float(expit(np.asarray(no_ans_score) / 8))
+            score=no_ans_score_scaled
             if use_confidence_scores
             else no_ans_score,  # just a pseudo prob for now or old score,
             context=None,
@@ -80,10 +88,27 @@ class BaseReader(BaseComponent):
     def run(self, query: str, documents: List[Document], top_k: Optional[int] = None, labels: Optional[MultiLabel] = None, add_isolated_node_eval: bool = False):  # type: ignore
         self.query_count += 1
         predict = self.timing(self.predict, "query_time")
+        # Remove empty text documents before making predictions
+        documents = [d for d in documents if not isinstance(d.content, str) or d.content.strip() != ""]
         if documents:
             results = predict(query=query, documents=documents, top_k=top_k)
         else:
-            results = {"answers": []}
+            if hasattr(self, "return_no_answers") and self.return_no_answers:
+                no_ans_prediction = Answer(
+                    answer="",
+                    type="extractive",
+                    score=1.0
+                    if hasattr(self, "use_confidence_scores") and self.use_confidence_scores
+                    else 1024.0,  # just a pseudo prob for now or old score,
+                    context=None,
+                    offsets_in_context=[Span(start=0, end=0)],
+                    offsets_in_document=[Span(start=0, end=0)],
+                    document_id=None,
+                    meta=None,
+                )
+                results = {"answers": [no_ans_prediction]}
+            else:
+                results = {"answers": []}
 
         # Add corresponding document_name and more meta data, if an answer contains the document_id
         results["answers"] = [
@@ -92,7 +117,9 @@ class BaseReader(BaseComponent):
 
         # run evaluation with labels as node inputs
         if add_isolated_node_eval and labels is not None:
-            relevant_documents = {label.document.id: label.document for label in labels.labels}.values()
+            relevant_documents = [label.document for label in labels.labels]
+            # Filter out empty documents
+            relevant_documents = [d for d in relevant_documents if d.content.strip() != ""]
             results_label_input = predict(query=query, documents=relevant_documents, top_k=top_k)
 
             # Add corresponding document_name and more meta data, if an answer contains the document_id
@@ -113,6 +140,14 @@ class BaseReader(BaseComponent):
         add_isolated_node_eval: bool = False,
     ):
         self.query_count += len(queries)
+
+        # Remove empty documents before making predictions
+        if len(documents) > 0:
+            if isinstance(documents[0], Document):
+                documents = [d for d in documents if not isinstance(d.content, str) or d.content.strip() != ""]  # type: ignore[union-attr, assignment]
+            else:
+                documents = [[d for d in docs_per_query if not isinstance(d.content, str) or d.content.strip() != ""] for docs_per_query in documents]  # type: ignore[union-attr]
+
         if not documents:
             return {"answers": []}, "output_1"
 
@@ -138,7 +173,11 @@ class BaseReader(BaseComponent):
         if add_isolated_node_eval and labels is not None:
             relevant_documents = []
             for labelx in labels:
-                relevant_documents.append([label.document for label in labelx.labels])
+                # Filter out empty documents
+                relevant_docs_labelx = [
+                    label.document for label in labelx.labels if label.document.content.strip() != ""
+                ]
+                relevant_documents.append(relevant_docs_labelx)
             results_label_input = predict_batch(queries=queries, documents=relevant_documents, top_k=top_k)
 
             # Add corresponding document_name and more meta data, if an answer contains the document_id
diff --git a/haystack/nodes/reader/farm.py b/haystack/nodes/reader/farm.py
index b04cf76de..b82c40205 100644
--- a/haystack/nodes/reader/farm.py
+++ b/haystack/nodes/reader/farm.py
@@ -832,7 +832,6 @@ class FARMReader(BaseReader):
         # Group predictions together
         grouped_predictions = []
         left_idx = 0
-        right_idx = 0
         for number in number_of_docs:
             right_idx = left_idx + number
             grouped_predictions.append(predictions[left_idx:right_idx])
diff --git a/haystack/nodes/reader/transformers.py b/haystack/nodes/reader/transformers.py
index 3de52c936..e588caaee 100644
--- a/haystack/nodes/reader/transformers.py
+++ b/haystack/nodes/reader/transformers.py
@@ -233,7 +233,6 @@ class TransformersReader(BaseReader):
         grouped_predictions = []
         grouped_inputs = []
         left_idx = 0
-        right_idx = 0
         for number in number_of_docs:
             right_idx = left_idx + number
             grouped_predictions.append(predictions[left_idx:right_idx])
@@ -247,7 +246,9 @@ class TransformersReader(BaseReader):
                 for pred in preds_for_single_doc:
                     cur_doc_id = inp.doc_id
                     pred["doc_id"] = cur_doc_id
-            if isinstance(grouped_pred[0], list):
+            if len(grouped_pred) == 0:
+                group = []
+            elif isinstance(grouped_pred[0], list):
                 group = list(itertools.chain.from_iterable(grouped_pred))
             answers, max_no_ans_gap = self._extract_answers_of_predictions(group, all_docs, top_k)
             results["answers"].append(answers)
@@ -271,8 +272,9 @@ class TransformersReader(BaseReader):
 
         no_ans_gaps = []
         best_overall_score = 0
-        cur_doc_id = predictions[0]["doc_id"]
-        cur_doc = docs[cur_doc_id]
+        if len(predictions) > 0:
+            cur_doc_id = predictions[0]["doc_id"]
+            cur_doc = docs[cur_doc_id]
         no_ans_doc_score = 0
         best_doc_score = 0
@@ -313,7 +315,9 @@ class TransformersReader(BaseReader):
         # + add no_ans_gap for last Document
         if best_doc_score > best_overall_score:
             best_overall_score = best_doc_score
-        no_ans_gaps.append(no_ans_doc_score - best_doc_score)
+
+        if len(predictions) > 0:
+            no_ans_gaps.append(no_ans_doc_score - best_doc_score)
 
         # Calculate the score for predicting "no answer", relative to our best positive answer score
         no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps, best_overall_score)
diff --git a/test/nodes/test_reader.py b/test/nodes/test_reader.py
index 28dfb3d74..ca7af5f44 100644
--- a/test/nodes/test_reader.py
+++ b/test/nodes/test_reader.py
@@ -8,7 +8,7 @@ import pytest
 from huggingface_hub import snapshot_download
 
 from haystack.modeling.data_handler.inputs import QAInput, Question
-from haystack.schema import Document, Answer
+from haystack.schema import Document, Answer, Label, MultiLabel, Span
 from haystack.nodes.reader.base import BaseReader
 from haystack.nodes import FARMReader, TransformersReader
 
@@ -32,6 +32,7 @@ def no_answer_reader(request):
         tokenizer="deepset/bert-medium-squad2-distilled",
         use_gpu=-1,
         top_k_per_candidate=5,
+        return_no_answers=True,
     )
 
 
@@ -175,7 +176,6 @@ def test_context_window_size(reader, docs, window_size):
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
 @pytest.mark.parametrize("top_k", [2, 5, 10])
 def test_top_k(reader, docs, top_k):
-
     assert isinstance(reader, FARMReader)
     old_top_k_per_candidate = reader.top_k_per_candidate
@@ -352,3 +352,56 @@ def test_farm_reader_onnx_conversion_and_inference(model_name, tmpdir, docs):
     reader = FARMReader(str(Path(tmpdir, "onnx")))
     result = reader.predict(query="Where does Paul live?", documents=[docs[0]])
     assert result["answers"][0].answer == "New York"
+
+
+LABELS = [
+    MultiLabel(
+        labels=[
+            Label(
+                query="Who lives in Berlin?",
+                answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
+                document=Document(
+                    id="a0747b83aea0b60c4b114b15476dd32d", content_type="text", content=""  # empty document
+                ),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+            )
+        ]
+    ),
+    MultiLabel(
+        labels=[
+            Label(
+                query="Who lives in Munich?",
+                answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
+                document=Document(
+                    id="something_else", content_type="text", content="My name is Carla and I live in Munich"
+                ),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+            )
+        ]
+    ),
+]
+
+
+def test_reader_skips_empty_documents(reader):
+    predictions, _ = reader.run(query=LABELS[0].labels[0].query, documents=[LABELS[0].labels[0].document])
+    assert predictions["answers"] == []  # no answer given for query as document is empty
+    predictions, _ = reader.run_batch(
+        queries=[l.labels[0].query for l in LABELS], documents=[[l.labels[0].document] for l in LABELS]
+    )
+    assert predictions["answers"][0] == []  # no answer given for 1st query as document is empty
+    assert predictions["answers"][1][0].answer == "Carla"  # answer given for 2nd query as usual
+
+
+@pytest.mark.parametrize("no_answer_reader", ["farm", "transformers"], indirect=True)
+def test_no_answer_reader_skips_empty_documents(no_answer_reader):
+    predictions, _ = no_answer_reader.run(query=LABELS[0].labels[0].query, documents=[LABELS[0].labels[0].document])
+    assert predictions["answers"][0].answer == ""  # Return no_answer as document is empty
+    predictions, _ = no_answer_reader.run_batch(
+        queries=[l.labels[0].query for l in LABELS], documents=[[l.labels[0].document] for l in LABELS]
+    )
+    assert predictions["answers"][0][0].answer == ""  # Return no_answer for 1st query as document is empty
+    assert predictions["answers"][1][1].answer == "Carla"  # answer given for 2nd query as usual
diff --git a/test/pipelines/test_eval.py b/test/pipelines/test_eval.py
index 80f93cc4b..b69b04e64 100644
--- a/test/pipelines/test_eval.py
+++ b/test/pipelines/test_eval.py
@@ -1411,3 +1411,28 @@ def test_multi_retriever_pipeline_with_asymmetric_qa_eval(document_store_with_do
 
     assert metrics["QAReader"]["exact_match"] == 1.0
     assert metrics["QAReader"]["f1"] == 1.0
+
+
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+@pytest.mark.parametrize("reader", ["farm", "transformers"], indirect=True)
+def test_empty_documents_dont_fail_pipeline(reader, retriever_with_docs):
+    multilabels = EVAL_LABELS[:2]
+    multilabels[0].labels[0].document.content = ""
+    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
+    eval_result_integrated: EvaluationResult = pipeline.eval(labels=multilabels, add_isolated_node_eval=False)
+    assert eval_result_integrated["Reader"]["answer"].iloc[0] == "Carla"
+    eval_result_iso: EvaluationResult = pipeline.eval(labels=multilabels, add_isolated_node_eval=True)
+    assert eval_result_iso["Reader"].loc[eval_result_iso["Reader"]["eval_mode"] == "isolated"]["answer"].iloc[0] == ""
+
+    eval_batch_result_integrated: EvaluationResult = pipeline.eval_batch(
+        labels=multilabels, add_isolated_node_eval=False
+    )
+    assert eval_batch_result_integrated["Reader"]["answer"].iloc[0] == "Carla"
+    eval_batch_result_iso: EvaluationResult = pipeline.eval_batch(labels=multilabels, add_isolated_node_eval=True)
+    assert (
+        eval_batch_result_iso["Reader"]
+        .loc[eval_batch_result_iso["Reader"]["eval_mode"] == "isolated"]["answer"]
+        .iloc[0]
+        == ""
+    )
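
Below is a brief usage sketch (not taken from the patch itself) of the behavior the changes above introduce. The model name is borrowed from the no_answer_reader fixture in test_reader.py; the rest is assumed, ordinary Haystack v1 reader usage. With the fix, a reader node that is handed only empty documents returns an empty answer list, or a single no_answer prediction when return_no_answers=True, instead of failing.

from haystack.nodes import FARMReader
from haystack.schema import Document

# Model name taken from the test fixtures above; any extractive QA model behaves the same way.
reader = FARMReader(
    model_name_or_path="deepset/bert-medium-squad2-distilled",
    use_gpu=False,
    return_no_answers=True,
)

empty_doc = Document(content="")  # filtered out by BaseReader.run() before prediction
full_doc = Document(content="My name is Carla and I live in Munich")

# Mixed input: the empty document is skipped, the non-empty one is read as usual.
results, _ = reader.run(query="Who lives in Munich?", documents=[empty_doc, full_doc], top_k=3)
print([answer.answer for answer in results["answers"]])

# Only empty input: instead of failing, the reader returns a single no_answer
# prediction here (return_no_answers=True); with return_no_answers=False it
# would return an empty answer list.
results, _ = reader.run(query="Who lives in Berlin?", documents=[empty_doc])
print(results["answers"])

Because the filtering happens in BaseReader.run() and run_batch(), pipelines whose retrievers occasionally surface empty documents no longer crash the reader node, which is what test_empty_documents_dont_fail_pipeline in test_eval.py verifies.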