bug: skip empty documents in reader (#3773)

* skip empty documents

* test eval_batch and account for tables
Julian Risch 2023-01-03 15:50:14 +01:00 committed by GitHub
parent 43328d2744
commit a2c160e7d8
6 changed files with 142 additions and 18 deletions
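
For context, here is a minimal usage sketch of the behaviour this commit targets; the reader setup below is illustrative and not part of the diff. After the fix, a reader that is handed only empty documents returns an empty answer list instead of failing:

    from haystack.nodes import FARMReader
    from haystack.schema import Document

    # Illustrative model choice (downloaded on first use); any extractive QA model works.
    reader = FARMReader(model_name_or_path="deepset/bert-medium-squad2-distilled", use_gpu=False)

    # An empty text document previously reached the model; it is now filtered out up front.
    empty_doc = Document(content="", content_type="text")

    results, _ = reader.run(query="Who lives in Berlin?", documents=[empty_doc])
    print(results["answers"])  # [] -- the only document is empty, so no prediction is attempted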

View File

@@ -508,6 +508,10 @@ class QAInferencer(Inferencer):
This parameter has no effect; it will be removed as Inferencer multiprocessing
has been deprecated.
"""
# Return no predictions if there are no inputs
if not objects:
return []
dicts = [o.to_dict() for o in objects]
# TODO investigate this deprecation warning. Timo: I thought we were about to implement Input Objects,
# then we can and should use inference from (input) objects!
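
The change in this file is a plain early-return guard for empty input. A self-contained sketch of the same pattern, with a stand-in predictor instead of the real Inferencer internals (both function names below are illustrative):

    from typing import Any, Dict, List

    def _predict(dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Stand-in for the actual model call.
        return [{"query": d.get("query"), "answers": []} for d in dicts]

    def inference_from_dicts(dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Return no predictions if there are no inputs, instead of invoking the model on nothing.
        if not dicts:
            return []
        return _predict(dicts)

    assert inference_from_dicts([]) == []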

View File

@@ -40,17 +40,25 @@ class BaseReader(BaseComponent):
# the most significant difference between scores.
# Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
# No_ans_gap is a list of this most significant difference per document
no_ans_gap_array = np.array(no_ans_gaps)
max_no_ans_gap = np.max(no_ans_gap_array)
# case 1: all passages "no answer" as top score
# max_no_ans_gap is negative, so it increases best pos score
# case 2: at least one passage predicts an answer (positive no_ans_gap)
no_ans_score = best_score_answer - max_no_ans_gap
# If there is not even one predicted answer, we return a no_answer with score 1.0
if best_score_answer == 0 and len(no_ans_gaps) == 0:
no_ans_score = 1024.0
no_ans_score_scaled = 1.0
max_no_ans_gap = 1024.0
else:
no_ans_gap_array = np.array(no_ans_gaps)
max_no_ans_gap = np.max(no_ans_gap_array)
# case 1: all passages "no answer" as top score
# max_no_ans_gap is negative, so it increases best pos score
# case 2: at least one passage predicts an answer (positive no_ans_gap)
no_ans_score = best_score_answer - max_no_ans_gap
no_ans_score_scaled = float(expit(np.asarray(no_ans_score) / 8))
no_ans_prediction = Answer(
answer="",
type="extractive",
score=float(expit(np.asarray(no_ans_score) / 8))
score=no_ans_score_scaled
if use_confidence_scores
else no_ans_score, # just a pseudo prob for now or old score,
context=None,
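
The comments above describe how the raw "no answer" score is derived from the per-document no_ans_gap values and then squashed into a pseudo-probability with expit(x / 8), scipy's logistic sigmoid; the new branch simply pins that score to 1.0 when there are no predictions at all. A small numeric sketch of the regular path, with made-up scores:

    import numpy as np
    from scipy.special import expit  # logistic sigmoid: 1 / (1 + exp(-x))

    best_score_answer = 10.0        # best positive answer score across documents (made up)
    no_ans_gaps = [-2.0, 4.0, 1.5]  # per-document gaps; a positive value means that document predicts an answer

    max_no_ans_gap = np.max(np.array(no_ans_gaps))     # 4.0 -> case 2: at least one passage predicts an answer
    no_ans_score = best_score_answer - max_no_ans_gap  # 6.0, the raw "no answer" score
    no_ans_score_scaled = float(expit(np.asarray(no_ans_score) / 8))
    print(round(no_ans_score_scaled, 3))               # ~0.679, a confidence comparable to other answers' scores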
@@ -80,10 +88,27 @@ class BaseReader(BaseComponent):
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None, labels: Optional[MultiLabel] = None, add_isolated_node_eval: bool = False): # type: ignore
self.query_count += 1
predict = self.timing(self.predict, "query_time")
# Remove empty text documents before making predictions
documents = [d for d in documents if not isinstance(d.content, str) or d.content.strip() != ""]
if documents:
results = predict(query=query, documents=documents, top_k=top_k)
else:
results = {"answers": []}
if hasattr(self, "return_no_answers") and self.return_no_answers:
no_ans_prediction = Answer(
answer="",
type="extractive",
score=1.0
if hasattr(self, "use_confidence_scores") and self.use_confidence_scores
else 1024.0, # just a pseudo prob for now or old score,
context=None,
offsets_in_context=[Span(start=0, end=0)],
offsets_in_document=[Span(start=0, end=0)],
document_id=None,
meta=None,
)
results = {"answers": [no_ans_prediction]}
else:
results = {"answers": []}
# Add corresponding document_name and more meta data, if an answer contains the document_id
results["answers"] = [
@@ -92,7 +117,9 @@ class BaseReader(BaseComponent):
# run evaluation with labels as node inputs
if add_isolated_node_eval and labels is not None:
relevant_documents = {label.document.id: label.document for label in labels.labels}.values()
relevant_documents = [label.document for label in labels.labels]
# Filter out empty documents
relevant_documents = [d for d in relevant_documents if d.content.strip() != ""]
results_label_input = predict(query=query, documents=relevant_documents, top_k=top_k)
# Add corresponding document_name and more meta data, if an answer contains the document_id
@@ -113,6 +140,14 @@
add_isolated_node_eval: bool = False,
):
self.query_count += len(queries)
# Remove empty documents before making predictions
if len(documents) > 0:
if isinstance(documents[0], Document):
documents = [d for d in documents if not isinstance(d.content, str) or d.content.strip() != ""] # type: ignore[union-attr, assignment]
else:
documents = [[d for d in docs_per_query if not isinstance(d.content, str) or d.content.strip() != ""] for docs_per_query in documents] # type: ignore[union-attr]
if not documents:
return {"answers": []}, "output_1"
@@ -138,7 +173,11 @@
if add_isolated_node_eval and labels is not None:
relevant_documents = []
for labelx in labels:
relevant_documents.append([label.document for label in labelx.labels])
# Filter out empty documents
relevant_docs_labelx = [
label.document for label in labelx.labels if label.document.content.strip() != ""
]
relevant_documents.append(relevant_docs_labelx)
results_label_input = predict_batch(queries=queries, documents=relevant_documents, top_k=top_k)
# Add corresponding document_name and more meta data, if an answer contains the document_id

View File

@@ -832,7 +832,6 @@ class FARMReader(BaseReader):
# Group predictions together
grouped_predictions = []
left_idx = 0
right_idx = 0
for number in number_of_docs:
right_idx = left_idx + number
grouped_predictions.append(predictions[left_idx:right_idx])
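
The loop above regroups the flat prediction list into one sub-list per query, using the per-query document counts, so dropping the redundant right_idx = 0 initialisation (it is overwritten on the first iteration) does not change behaviour. A standalone sketch of the slicing pattern with toy data:

    # Toy per-document predictions for 3 queries that had 2, 0 and 1 documents respectively.
    predictions = ["p0", "p1", "p2"]
    number_of_docs = [2, 0, 1]

    grouped_predictions = []
    left_idx = 0
    for number in number_of_docs:
        right_idx = left_idx + number
        grouped_predictions.append(predictions[left_idx:right_idx])
        left_idx = right_idx

    print(grouped_predictions)  # [['p0', 'p1'], [], ['p2']]

An empty group like the middle one, produced when all of a query's documents were filtered out as empty, is exactly what the new len(grouped_pred) == 0 check in the TransformersReader below handles.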

View File

@@ -233,7 +233,6 @@ class TransformersReader(BaseReader):
grouped_predictions = []
grouped_inputs = []
left_idx = 0
right_idx = 0
for number in number_of_docs:
right_idx = left_idx + number
grouped_predictions.append(predictions[left_idx:right_idx])
@@ -247,7 +246,9 @@
for pred in preds_for_single_doc:
cur_doc_id = inp.doc_id
pred["doc_id"] = cur_doc_id
if isinstance(grouped_pred[0], list):
if len(grouped_pred) == 0:
group = []
elif isinstance(grouped_pred[0], list):
group = list(itertools.chain.from_iterable(grouped_pred))
answers, max_no_ans_gap = self._extract_answers_of_predictions(group, all_docs, top_k)
results["answers"].append(answers)
@@ -271,8 +272,9 @@
no_ans_gaps = []
best_overall_score = 0
cur_doc_id = predictions[0]["doc_id"]
cur_doc = docs[cur_doc_id]
if len(predictions) > 0:
cur_doc_id = predictions[0]["doc_id"]
cur_doc = docs[cur_doc_id]
no_ans_doc_score = 0
best_doc_score = 0
@@ -313,7 +315,9 @@
# + add no_ans_gap for last Document
if best_doc_score > best_overall_score:
best_overall_score = best_doc_score
no_ans_gaps.append(no_ans_doc_score - best_doc_score)
if len(predictions) > 0:
no_ans_gaps.append(no_ans_doc_score - best_doc_score)
# Calculate the score for predicting "no answer", relative to our best positive answer score
no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps, best_overall_score)

View File

@@ -8,7 +8,7 @@ import pytest
from huggingface_hub import snapshot_download
from haystack.modeling.data_handler.inputs import QAInput, Question
from haystack.schema import Document, Answer
from haystack.schema import Document, Answer, Label, MultiLabel, Span
from haystack.nodes.reader.base import BaseReader
from haystack.nodes import FARMReader, TransformersReader
@@ -32,6 +32,7 @@ def no_answer_reader(request):
tokenizer="deepset/bert-medium-squad2-distilled",
use_gpu=-1,
top_k_per_candidate=5,
return_no_answers=True,
)
@@ -175,7 +176,6 @@ def test_context_window_size(reader, docs, window_size):
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("top_k", [2, 5, 10])
def test_top_k(reader, docs, top_k):
assert isinstance(reader, FARMReader)
old_top_k_per_candidate = reader.top_k_per_candidate
@@ -352,3 +352,56 @@ def test_farm_reader_onnx_conversion_and_inference(model_name, tmpdir, docs):
reader = FARMReader(str(Path(tmpdir, "onnx")))
result = reader.predict(query="Where does Paul live?", documents=[docs[0]])
assert result["answers"][0].answer == "New York"
LABELS = [
MultiLabel(
labels=[
Label(
query="Who lives in Berlin?",
answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
document=Document(
id="a0747b83aea0b60c4b114b15476dd32d", content_type="text", content="" # empty document
),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
)
]
),
MultiLabel(
labels=[
Label(
query="Who lives in Munich?",
answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
document=Document(
id="something_else", content_type="text", content="My name is Carla and I live in Munich"
),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
)
]
),
]
def test_reader_skips_empty_documents(reader):
predictions, _ = reader.run(query=LABELS[0].labels[0].query, documents=[LABELS[0].labels[0].document])
assert predictions["answers"] == [] # no answer given for query as document is empty
predictions, _ = reader.run_batch(
queries=[l.labels[0].query for l in LABELS], documents=[[l.labels[0].document] for l in LABELS]
)
assert predictions["answers"][0] == [] # no answer given for 1st query as document is empty
assert predictions["answers"][1][0].answer == "Carla" # answer given for 2nd query as usual
@pytest.mark.parametrize("no_answer_reader", ["farm", "transformers"], indirect=True)
def test_no_answer_reader_skips_empty_documents(no_answer_reader):
predictions, _ = no_answer_reader.run(query=LABELS[0].labels[0].query, documents=[LABELS[0].labels[0].document])
assert predictions["answers"][0].answer == "" # Return no_answer as document is empty
predictions, _ = no_answer_reader.run_batch(
queries=[l.labels[0].query for l in LABELS], documents=[[l.labels[0].document] for l in LABELS]
)
assert predictions["answers"][0][0].answer == "" # Return no_answer for 1st query as document is empty
assert predictions["answers"][1][1].answer == "Carla" # answer given for 2nd query as usual

View File

@@ -1411,3 +1411,28 @@ def test_multi_retriever_pipeline_with_asymmetric_qa_eval(document_store_with_do
assert metrics["QAReader"]["exact_match"] == 1.0
assert metrics["QAReader"]["f1"] == 1.0
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("reader", ["farm", "transformers"], indirect=True)
def test_empty_documents_dont_fail_pipeline(reader, retriever_with_docs):
multilabels = EVAL_LABELS[:2]
multilabels[0].labels[0].document.content = ""
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
eval_result_integrated: EvaluationResult = pipeline.eval(labels=multilabels, add_isolated_node_eval=False)
assert eval_result_integrated["Reader"]["answer"].iloc[0] == "Carla"
eval_result_iso: EvaluationResult = pipeline.eval(labels=multilabels, add_isolated_node_eval=True)
assert eval_result_iso["Reader"].loc[eval_result_iso["Reader"]["eval_mode"] == "isolated"]["answer"].iloc[0] == ""
eval_batch_result_integrated: EvaluationResult = pipeline.eval_batch(
labels=multilabels, add_isolated_node_eval=False
)
assert eval_batch_result_integrated["Reader"]["answer"].iloc[0] == "Carla"
eval_batch_result_iso: EvaluationResult = pipeline.eval_batch(labels=multilabels, add_isolated_node_eval=True)
assert (
eval_batch_result_iso["Reader"]
.loc[eval_batch_result_iso["Reader"]["eval_mode"] == "isolated"]["answer"]
.iloc[0]
== ""
)