Batch prediction in evaluation (#137)
* Add Batch evaluation
* Separate evaluation methods
* Clean calculation of eval metrics
* Adapt eval to Label objects
* Fix format of no_answer
* Adapt to MultiLabel
* Add tests
This commit is contained in:
parent 860f860b00
commit 5186d2d235
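For orientation, a minimal usage sketch of the batch evaluation this commit introduces. It mirrors the test added at the bottom of the diff; the document store setup, index names, file path, and the reader model are illustrative assumptions, not part of the commit.

# Hypothetical end-to-end sketch (assumes a running Elasticsearch instance).
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.sparse import ElasticsearchRetriever
from haystack.reader.farm import FARMReader
from haystack.finder import Finder

document_store = ElasticsearchDocumentStore()
# load SQuAD-style eval data into separate document and label indices
document_store.add_eval_data(filename="samples/squad/tiny.json",
                             doc_index="eval_document", label_index="label")

retriever = ElasticsearchRetriever(document_store=document_store)
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")  # illustrative model choice
finder = Finder(reader=reader, retriever=retriever)

# eval_batch() retrieves documents for all labels first and then sends every
# question-document pair to the reader in batches of `batch_size`.
results = finder.eval_batch(label_index="label", doc_index="eval_document",
                            top_k_retriever=10, top_k_reader=10, batch_size=50)
Finder.print_eval_results(results)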
@@ -90,7 +90,7 @@ class Label:
        or, user-feedback from the Haystack REST API.

        :param question: the question(or query) for finding answers.
        :param answer: teh answer string.
        :param answer: the answer string.
        :param is_correct_answer: whether the sample is positive or negative.
        :param is_correct_document: in case of negative sample(is_correct_answer is False), there could be two cases;
                                    incorrect answer but correct document & incorrect document. This flag denotes if
@@ -263,9 +263,8 @@ class BaseDocumentStore(ABC):
                is_correct_answer=l.is_correct_answer,
                is_correct_document=l.is_correct_document,
                origin=l.origin,
                multiple_document_ids=[l.document_id] if l.document_id else [],
                multiple_offset_start_in_docs=[
                    l.offset_start_in_doc] if l.offset_start_in_doc else [],
                multiple_document_ids=[l.document_id],
                multiple_offset_start_in_docs=[l.offset_start_in_doc],
                no_answer=l.no_answer,
                model_id=l.model_id,
            )

haystack/eval.py (new file, 248 lines)
@@ -0,0 +1,248 @@
from typing import List, Tuple, Dict, Any

from haystack.database.base import MultiLabel


def calculate_reader_metrics(metric_counts: Dict[str, float], correct_retrievals: int):
    number_of_has_answer = correct_retrievals - metric_counts["number_of_no_answer"]

    metrics = {
        "reader_top1_accuracy" : metric_counts["correct_readings_top1"] / correct_retrievals,
        "reader_top1_accuracy_has_answer" : metric_counts["correct_readings_top1_has_answer"] / number_of_has_answer,
        "reader_topk_accuracy" : metric_counts["correct_readings_topk"] / correct_retrievals,
        "reader_topk_accuracy_has_answer" : metric_counts["correct_readings_topk_has_answer"] / number_of_has_answer,
        "reader_top1_em" : metric_counts["exact_matches_top1"] / correct_retrievals,
        "reader_top1_em_has_answer" : metric_counts["exact_matches_top1_has_answer"] / number_of_has_answer,
        "reader_topk_em" : metric_counts["exact_matches_topk"] / correct_retrievals,
        "reader_topk_em_has_answer" : metric_counts["exact_matches_topk_has_answer"] / number_of_has_answer,
        "reader_top1_f1" : metric_counts["summed_f1_top1"] / correct_retrievals,
        "reader_top1_f1_has_answer" : metric_counts["summed_f1_top1_has_answer"] / number_of_has_answer,
        "reader_topk_f1" : metric_counts["summed_f1_topk"] / correct_retrievals,
        "reader_topk_f1_has_answer" : metric_counts["summed_f1_topk_has_answer"] / number_of_has_answer,
    }

    if metric_counts["number_of_no_answer"]:
        metrics["reader_top1_no_answer_accuracy"] = metric_counts["correct_no_answers_top1"] / metric_counts[
            "number_of_no_answer"]
        metrics["reader_topk_no_answer_accuracy"] = metric_counts["correct_no_answers_topk"] / metric_counts[
            "number_of_no_answer"]
    else:
        metrics["reader_top1_no_answer_accuracy"] = None  # type: ignore
        metrics["reader_topk_no_answer_accuracy"] = None  # type: ignore

    return metrics


def calculate_average_precision(questions_with_docs: List[dict]):
    questions_with_correct_doc = []
    summed_avg_precision_retriever = 0.0

    for question in questions_with_docs:
        for doc_idx, doc in enumerate(question["docs"]):
            # check if correct doc among retrieved docs
            if doc.id in question["question"].multiple_document_ids:
                summed_avg_precision_retriever += 1 / (doc_idx + 1)
                questions_with_correct_doc.append({
                    "question": question["question"],
                    "docs": question["docs"]
                })
                break

    return questions_with_correct_doc, summed_avg_precision_retriever


def eval_counts_reader(question: MultiLabel, predicted_answers: Dict[str, Any], metric_counts: Dict[str, float]):
    # Calculates evaluation metrics for one question and adds results to counter.
    # check if question is answerable
    if not question.no_answer:
        found_answer = False
        found_em = False
        best_f1 = 0
        for answer_idx, answer in enumerate(predicted_answers["answers"]):
            if answer["document_id"] in question.multiple_document_ids:
                gold_spans = [{"offset_start": question.multiple_offset_start_in_docs[i],
                               "offset_end": question.multiple_offset_start_in_docs[i] + len(question.multiple_answers[i]),
                               "doc_id": question.multiple_document_ids[i]} for i in range(len(question.multiple_answers))]  # type: ignore
                predicted_span = {"offset_start": answer["offset_start_in_doc"],
                                  "offset_end": answer["offset_end_in_doc"],
                                  "doc_id": answer["document_id"]}
                best_f1_in_gold_spans = 0
                for gold_span in gold_spans:
                    if gold_span["doc_id"] == predicted_span["doc_id"]:
                        # check if overlap between gold answer and predicted answer
                        if not found_answer:
                            metric_counts, found_answer = _count_overlap(gold_span, predicted_span, metric_counts, answer_idx)  # type: ignore

                        # check for exact match
                        if not found_em:
                            metric_counts, found_em = _count_exact_match(gold_span, predicted_span, metric_counts, answer_idx)  # type: ignore

                        # calculate f1
                        current_f1 = _calculate_f1(gold_span, predicted_span)  # type: ignore
                        if current_f1 > best_f1_in_gold_spans:
                            best_f1_in_gold_spans = current_f1
                # top-1 f1
                if answer_idx == 0:
                    metric_counts["summed_f1_top1"] += best_f1_in_gold_spans
                    metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans
                if best_f1_in_gold_spans > best_f1:
                    best_f1 = best_f1_in_gold_spans

            if found_em:
                break
        # top-k answers: use best f1-score
        metric_counts["summed_f1_topk"] += best_f1
        metric_counts["summed_f1_topk_has_answer"] += best_f1

    # question not answerable
    else:
        metric_counts["number_of_no_answer"] += 1
        metric_counts = _count_no_answer(predicted_answers["answers"], metric_counts)

    return metric_counts


def eval_counts_reader_batch(pred: Dict[str, Any], metric_counts: Dict[str, float]):
    # Calculates evaluation metrics for one question and adds results to counter.

    # check if question is answerable
    if not pred["label"].no_answer:
        found_answer = False
        found_em = False
        best_f1 = 0
        for answer_idx, answer in enumerate(pred["answers"]):
            # check if correct document:
            if answer["document_id"] in pred["label"].multiple_document_ids:
                gold_spans = [{"offset_start": pred["label"].multiple_offset_start_in_docs[i],
                               "offset_end": pred["label"].multiple_offset_start_in_docs[i] + len(pred["label"].multiple_answers[i]),
                               "doc_id": pred["label"].multiple_document_ids[i]}
                              for i in range(len(pred["label"].multiple_answers))]  # type: ignore
                predicted_span = {"offset_start": answer["offset_start_in_doc"],
                                  "offset_end": answer["offset_end_in_doc"],
                                  "doc_id": answer["document_id"]}

                best_f1_in_gold_spans = 0
                for gold_span in gold_spans:
                    if gold_span["doc_id"] == predicted_span["doc_id"]:
                        # check if overlap between gold answer and predicted answer
                        if not found_answer:
                            metric_counts, found_answer = _count_overlap(
                                gold_span, predicted_span, metric_counts, answer_idx
                            )
                        # check for exact match
                        if not found_em:
                            metric_counts, found_em = _count_exact_match(
                                gold_span, predicted_span, metric_counts, answer_idx
                            )
                        # calculate f1
                        current_f1 = _calculate_f1(gold_span, predicted_span)
                        if current_f1 > best_f1_in_gold_spans:
                            best_f1_in_gold_spans = current_f1
                # top-1 f1
                if answer_idx == 0:
                    metric_counts["summed_f1_top1"] += best_f1_in_gold_spans
                    metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans
                if best_f1_in_gold_spans > best_f1:
                    best_f1 = best_f1_in_gold_spans

            if found_em:
                break

        # top-k answers: use best f1-score
        metric_counts["summed_f1_topk"] += best_f1
        metric_counts["summed_f1_topk_has_answer"] += best_f1

    # question not answerable
    else:
        metric_counts["number_of_no_answer"] += 1
        metric_counts = _count_no_answer(pred["answers"], metric_counts)

    return metric_counts


def _count_overlap(
    gold_span: Dict[str, Any],
    predicted_span: Dict[str, Any],
    metric_counts: Dict[str, float],
    answer_idx: int
):
    # Checks if overlap between prediction and real answer.

    found_answer = False

    if (gold_span["offset_start"] <= predicted_span["offset_end"]) and \
       (predicted_span["offset_start"] <= gold_span["offset_end"]):
        # top-1 answer
        if answer_idx == 0:
            metric_counts["correct_readings_top1"] += 1
            metric_counts["correct_readings_top1_has_answer"] += 1
        # top-k answers
        metric_counts["correct_readings_topk"] += 1
        metric_counts["correct_readings_topk_has_answer"] += 1
        found_answer = True

    return metric_counts, found_answer


def _count_exact_match(
    gold_span: Dict[str, Any],
    predicted_span: Dict[str, Any],
    metric_counts: Dict[str, float],
    answer_idx: int
):
    # Check if exact match between prediction and real answer.
    # As evaluation needs to be framework independent, we cannot use the farm.evaluation.metrics.py functions.

    found_em = False

    if (gold_span["offset_start"] == predicted_span["offset_start"]) and \
       (gold_span["offset_end"] == predicted_span["offset_end"]):
        # top-1 answer
        if answer_idx == 0:
            metric_counts["exact_matches_top1"] += 1
            metric_counts["exact_matches_top1_has_answer"] += 1
        # top-k answers
        metric_counts["exact_matches_topk"] += 1
        metric_counts["exact_matches_topk_has_answer"] += 1
        found_em = True

    return metric_counts, found_em


def _calculate_f1(gold_span: Dict[str, Any], predicted_span: Dict[str, Any]):
    # Calculates F1-Score for prediction based on real answer using character offsets.
    # As evaluation needs to be framework independent, we cannot use the farm.evaluation.metrics.py functions.

    pred_indices = list(range(predicted_span["offset_start"], predicted_span["offset_end"]))
    gold_indices = list(range(gold_span["offset_start"], gold_span["offset_end"]))
    n_overlap = len([x for x in pred_indices if x in gold_indices])
    if pred_indices and gold_indices and n_overlap:
        precision = n_overlap / len(pred_indices)
        recall = n_overlap / len(gold_indices)
        f1 = (2 * precision * recall) / (precision + recall)

        return f1
    else:
        return 0


def _count_no_answer(answers: List[dict], metric_counts: Dict[str, float]):
    # Checks if one of the answers is 'no answer'.

    for answer_idx, answer in enumerate(answers):
        # check if 'no answer'
        if answer["answer"] is None:
            # top-1 answer
            if answer_idx == 0:
                metric_counts["correct_no_answers_top1"] += 1
                metric_counts["correct_readings_top1"] += 1
                metric_counts["exact_matches_top1"] += 1
                metric_counts["summed_f1_top1"] += 1
            # top-k answers
            metric_counts["correct_no_answers_topk"] += 1
            metric_counts["correct_readings_topk"] += 1
            metric_counts["exact_matches_topk"] += 1
            metric_counts["summed_f1_topk"] += 1
            break

    return metric_counts
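To make the data flow of the new module concrete, here is a small hedged sketch of how its pure helpers behave. The spans and counts below are invented values, and the underscore-prefixed functions are module-internal, so this is illustration rather than recommended API use.

# Invented offsets and counts, purely to illustrate the helpers above.
from collections import defaultdict
from haystack.eval import _calculate_f1, calculate_reader_metrics

gold_span = {"offset_start": 10, "offset_end": 20, "doc_id": "d1"}
predicted_span = {"offset_start": 12, "offset_end": 20, "doc_id": "d1"}

# character-level F1: 8 overlapping positions, precision 8/8, recall 8/10 -> ~0.889
print(_calculate_f1(gold_span, predicted_span))

# eval_counts_reader(_batch) fills a defaultdict(float); calculate_reader_metrics
# then turns the raw counts into ratios over the correctly retrieved questions.
counts = defaultdict(float)
counts["correct_readings_top1"] = 1
counts["correct_readings_topk"] = 1
metrics = calculate_reader_metrics(counts, correct_retrievals=2)
print(metrics["reader_top1_accuracy"])  # 0.5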
@@ -1,13 +1,17 @@
import logging
import time
from statistics import mean
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, List
from collections import defaultdict

import numpy as np
from scipy.special import expit

from haystack.reader.base import BaseReader
from haystack.retriever.base import BaseRetriever
from haystack.database.base import MultiLabel, Document
from haystack.eval import calculate_average_precision, eval_counts_reader_batch, calculate_reader_metrics, \
    eval_counts_reader

logger = logging.getLogger(__name__)

@@ -110,8 +114,8 @@ class Finder:

    def eval(
        self,
        label_index: str = "label",
        doc_index: str = "eval_document",
        label_index: str,
        doc_index: str,
        label_origin: str = "gold_label",
        top_k_retriever: int = 10,
        top_k_reader: int = 10,
@@ -161,8 +165,6 @@ class Finder:
        :param top_k_reader: How many answers to return per question
        :type top_k_reader: int
        """
        raise NotImplementedError("The Finder evaluation is unavailable in the current Haystack version due to code "
                                  "refactoring in-progress. Please use Reader and Retriever evaluation.")

        if not self.reader or not self.retriever:
            raise Exception("Finder needs to have a reader and retriever for the evaluation.")
@@ -170,192 +172,194 @@ class Finder:
        finder_start_time = time.time()
        # extract all questions for evaluation
        filters = {"origin": [label_origin]}
        questions = self.retriever.document_store.get_all_documents_in_index(index=label_index, filters=filters)  # type: ignore
        questions = self.retriever.document_store.get_all_labels_aggregated(index=label_index, filters=filters)

        correct_retrievals = 0
        summed_avg_precision_retriever = 0
        counts = defaultdict(float)  # type: Dict[str, float]
        retrieve_times = []

        correct_readings_top1 = 0
        correct_readings_topk = 0
        correct_readings_top1_has_answer = 0
        correct_readings_topk_has_answer = 0
        exact_matches_top1 = 0
        exact_matches_topk = 0
        exact_matches_top1_has_answer = 0
        exact_matches_topk_has_answer = 0
        summed_f1_top1 = 0
        summed_f1_topk = 0
        summed_f1_top1_has_answer = 0
        summed_f1_topk_has_answer = 0
        correct_no_answers_top1 = 0
        correct_no_answers_topk = 0
        read_times = []

        # retrieve documents
        questions_with_docs = []
        retriever_start_time = time.time()
        for q_idx, question in enumerate(questions):
            question_string = question["_source"]["question"]
            question_string = question.question
            single_retrieve_start = time.time()
            retrieved_docs = self.retriever.retrieve(question_string, top_k=top_k_retriever, index=doc_index)
            retrieve_times.append(time.time() - single_retrieve_start)

            # check if correct doc among retrieved docs
            for doc_idx, doc in enumerate(retrieved_docs):
                # check if correct doc among retrieved docs
                if doc.meta["doc_id"] == question["_source"]["doc_id"]:  # type: ignore
                    correct_retrievals += 1
                    summed_avg_precision_retriever += 1 / (doc_idx + 1)  # type: ignore
                if doc.id in question.multiple_document_ids:
                    counts["correct_retrievals"] += 1
                    counts["summed_avg_precision_retriever"] += 1 / (doc_idx + 1)
                    questions_with_docs.append({
                        "question": question,
                        "docs": retrieved_docs,
                        "correct_es_doc_id": doc.id})
                        "docs": retrieved_docs
                    })
                    break
        retriever_total_time = time.time() - retriever_start_time
        number_of_questions = q_idx + 1

        number_of_no_answer = 0
        previous_return_no_answers = self.reader.return_no_answers  # type: ignore
        self.reader.return_no_answers = True  # type: ignore
        retriever_total_time = time.time() - retriever_start_time
        counts["number_of_questions"] = q_idx + 1

        previous_return_no_answers = self.reader.return_no_answers
        self.reader.return_no_answers = True

        # extract answers
        reader_start_time = time.time()
        for q_idx, question in enumerate(questions_with_docs):
        for q_idx, question_docs in enumerate(questions_with_docs):
            if (q_idx + 1) % 100 == 0:
                print(f"Processed {q_idx+1} questions.")
            question_string = question["question"]["_source"]["question"]
            docs = question["docs"]

            question = question_docs["question"]  # type: ignore
            question_string = question.question
            docs = question_docs["docs"]  # type: ignore
            single_reader_start = time.time()
            predicted_answers = self.reader.predict(question_string, docs, top_k_reader)
            predicted_answers = self.reader.predict(question_string, docs, top_k=top_k_reader)  # type: ignore
            read_times.append(time.time() - single_reader_start)
            # check if question is answerable
            if question["question"]["_source"]["answers"]:
                for answer_idx, answer in enumerate(predicted_answers["answers"]):
                    found_answer = False
                    found_em = False
                    best_f1 = 0
                    # check if correct document
                    if answer["document_id"] == question["correct_es_doc_id"]:
                        gold_spans = [(gold_answer["answer_start"], gold_answer["answer_start"] + len(gold_answer["text"]) + 1)
                                      for gold_answer in question["question"]["_source"]["answers"]]
                        predicted_span = (answer["offset_start_in_doc"], answer["offset_end_in_doc"])
            counts = eval_counts_reader(question, predicted_answers, counts)

                        for gold_span in gold_spans:
                            # check if overlap between gold answer and predicted answer
                            # top-1 answer
                            if not found_answer:
                                if (gold_span[0] <= predicted_span[1]) and (predicted_span[0] <= gold_span[1]):
                                    # top-1 answer
                                    if answer_idx == 0:
                                        correct_readings_top1 += 1
                                        correct_readings_top1_has_answer += 1
                                    # top-k answers
                                    correct_readings_topk += 1
                                    correct_readings_topk_has_answer += 1
                                    found_answer = True
                            # check for exact match
                            if not found_em:
                                if (gold_span[0] == predicted_span[0]) and (gold_span[1] == predicted_span[1]):
                                    # top-1-answer
                                    if answer_idx == 0:
                                        exact_matches_top1 += 1
                                        exact_matches_top1_has_answer += 1
                                    # top-k answers
                                    exact_matches_topk += 1
                                    exact_matches_topk_has_answer += 1
                                    found_em = True
                            # calculate f1
                            pred_indices = list(range(predicted_span[0], predicted_span[1] + 1))
                            gold_indices = list(range(gold_span[0], gold_span[1] + 1))
                            n_overlap = len([x for x in pred_indices if x in gold_indices])
                            if pred_indices and gold_indices and n_overlap:
                                precision = n_overlap / len(pred_indices)
                                recall = n_overlap / len(gold_indices)
                                current_f1 = (2 * precision * recall) / (precision + recall)
                                # top-1 answer
                                if answer_idx == 0:
                                    summed_f1_top1 += current_f1  # type: ignore
                                    summed_f1_top1_has_answer += current_f1  # type: ignore
                                if current_f1 > best_f1:
                                    best_f1 = current_f1  # type: ignore
                        # top-k answers: use best f1-score
                        summed_f1_topk += best_f1
                        summed_f1_topk_has_answer += best_f1

                    if found_answer and found_em:
                        break
            # question not answerable
            else:
                number_of_no_answer += 1
                # As question is not answerable, it is not clear how to compute average precision for this question.
                # For now, we decided to calculate average precision based on the rank of 'no answer'.
                for answer_idx, answer in enumerate(predicted_answers["answers"]):
                    # check if 'no answer'
                    if answer["answer"] is None:
                        if answer_idx == 0:
                            correct_no_answers_top1 += 1
                            correct_readings_top1 += 1
                            exact_matches_top1 += 1
                            summed_f1_top1 += 1
                        correct_no_answers_topk += 1
                        correct_readings_topk += 1
                        exact_matches_topk += 1
                        summed_f1_topk += 1
                        break
        number_of_has_answer = correct_retrievals - number_of_no_answer
        counts["number_of_has_answer"] = counts["correct_retrievals"] - counts["number_of_no_answer"]

        reader_total_time = time.time() - reader_start_time
        finder_total_time = time.time() - finder_start_time

        retriever_recall = correct_retrievals / number_of_questions
        retriever_map = summed_avg_precision_retriever / number_of_questions

        reader_top1_accuracy = correct_readings_top1 / correct_retrievals
        reader_top1_accuracy_has_answer = correct_readings_top1_has_answer / number_of_has_answer
        reader_top_k_accuracy = correct_readings_topk / correct_retrievals
        reader_topk_accuracy_has_answer = correct_readings_topk_has_answer / number_of_has_answer
        reader_top1_em = exact_matches_top1 / correct_retrievals
        reader_top1_em_has_answer = exact_matches_top1_has_answer / number_of_has_answer
        reader_topk_em = exact_matches_topk / correct_retrievals
        reader_topk_em_has_answer = exact_matches_topk_has_answer / number_of_has_answer
        reader_top1_f1 = summed_f1_top1 / correct_retrievals
        reader_top1_f1_has_answer = summed_f1_top1_has_answer / number_of_has_answer
        reader_topk_f1 = summed_f1_topk / correct_retrievals
        reader_topk_f1_has_answer = summed_f1_topk_has_answer / number_of_has_answer
        reader_top1_no_answer_accuracy = correct_no_answers_top1 / number_of_no_answer
        reader_topk_no_answer_accuracy = correct_no_answers_topk / number_of_no_answer

        self.reader.return_no_answers = previous_return_no_answers  # type: ignore

        logger.info((f"{correct_readings_topk} out of {number_of_questions} questions were correctly answered "
                     f"({(correct_readings_topk/number_of_questions):.2%})."))
        logger.info(f"{number_of_questions-correct_retrievals} questions could not be answered due to the retriever.")
        logger.info(f"{correct_retrievals-correct_readings_topk} questions could not be answered due to the reader.")
        logger.info((f"{counts['correct_readings_topk']} out of {counts['number_of_questions']} questions were correctly"
                     f" answered {(counts['correct_readings_topk']/counts['number_of_questions']):.2%})."))
        logger.info((f"{counts['number_of_questions']-counts['correct_retrievals']} questions could not be answered due "
                     f"to the retriever."))
        logger.info((f"{counts['correct_retrievals']-counts['correct_readings_topk']} questions could not be answered "
                     f"due to the reader."))

        results = {
            "retriever_recall": retriever_recall,
            "retriever_map": retriever_map,
            "reader_top1_accuracy": reader_top1_accuracy,
            "reader_top1_accuracy_has_answer": reader_top1_accuracy_has_answer,
            "reader_top_k_accuracy": reader_top_k_accuracy,
            "reader_topk_accuracy_has_answer": reader_topk_accuracy_has_answer,
            "reader_top1_em": reader_top1_em,
            "reader_top1_em_has_answer": reader_top1_em_has_answer,
            "reader_topk_em": reader_topk_em,
            "reader_topk_em_has_answer": reader_topk_em_has_answer,
            "reader_top1_f1": reader_top1_f1,
            "reader_top1_f1_has_answer": reader_top1_f1_has_answer,
            "reader_topk_f1": reader_topk_f1,
            "reader_topk_f1_has_answer": reader_topk_f1_has_answer,
            "reader_top1_no_answer_accuracy": reader_top1_no_answer_accuracy,
            "reader_topk_no_answer_accuracy": reader_topk_no_answer_accuracy,
            "total_retrieve_time": retriever_total_time,
            "avg_retrieve_time": mean(retrieve_times),
            "total_reader_time": reader_total_time,
            "avg_reader_time": mean(read_times),
            "total_finder_time": finder_total_time
        }
        eval_results = self.calc_eval_results(counts)
        eval_results["total_retrieve_time"] = retriever_total_time
        eval_results["avg_retrieve_time"] = mean(retrieve_times)
        eval_results["total_reader_time"] = reader_total_time
        eval_results["avg_reader_time"] = mean(read_times)
        eval_results["total_finder_time"] = finder_total_time

        return eval_results

    def eval_batch(
        self,
        label_index: str,
        doc_index : str,
        label_origin: str = "gold_label",
        top_k_retriever: int = 10,
        top_k_reader: int = 10,
        batch_size: int = 50
    ):
        """
        Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
        of the Retriever. Passes all retrieved question-document pairs to the Reader at once.
        Returns a dict containing the following metrics:
        - "retriever_recall": Proportion of questions for which correct document is among retrieved documents
        - "retriever_map": Mean of average precision for each question. Rewards retrievers that give relevant
                           documents a higher rank.
        - "reader_top1_accuracy": Proportion of highest ranked predicted answers that overlap with corresponding correct answer
        - "reader_top1_accuracy_has_answer": Proportion of highest ranked predicted answers that overlap
                                             with corresponding correct answer for answerable questions
        - "reader_top_k_accuracy": Proportion of predicted answers that overlap with corresponding correct answer
        - "reader_topk_accuracy_has_answer": Proportion of predicted answers that overlap with corresponding correct answer
                                             for answerable questions
        - "reader_top1_em": Proportion of exact matches of highest ranked predicted answers with their corresponding
                            correct answers
        - "reader_top1_em_has_answer": Proportion of exact matches of highest ranked predicted answers with their corresponding
                                       correct answers for answerable questions
        - "reader_topk_em": Proportion of exact matches of predicted answers with their corresponding correct answers
        - "reader_topk_em_has_answer": Proportion of exact matches of predicted answers with their corresponding
                                       correct answers for answerable questions
        - "reader_top1_f1": Average overlap between highest ranked predicted answers and their corresponding correct answers
        - "reader_top1_f1_has_answer": Average overlap between highest ranked predicted answers and their corresponding
                                       correct answers for answerable questions
        - "reader_topk_f1": Average overlap between predicted answers and their corresponding correct answers
        - "reader_topk_f1_has_answer": Average overlap between predicted answers and their corresponding correct answers
                                       for answerable questions
        - "reader_top1_no_answer_accuracy": Proportion of correct predicting unanswerable question at highest ranked prediction
        - "reader_topk_no_answer_accuracy": Proportion of correct predicting unanswerable question among all predictions
        - "total_retrieve_time": Time retriever needed to retrieve documents for all questions
        - "avg_retrieve_time": Average time needed to retrieve documents for one question
        - "total_reader_time": Time reader needed to extract answer out of retrieved documents for all questions
                               where the correct document is among the retrieved ones
        - "avg_reader_time": Average time needed to extract answer out of retrieved documents for one question
        - "total_finder_time": Total time for whole pipeline
        :param label_index: Elasticsearch index where labeled questions are stored
        :type label_index: str
        :param doc_index: Elasticsearch index where documents that are used for evaluation are stored
        :type doc_index: str
        :param top_k_retriever: How many documents per question to return and pass to reader
        :type top_k_retriever: int
        :param top_k_reader: How many answers to return per question
        :type top_k_reader: int
        :param batch_size: Number of samples per batch computed at once
        :type batch_size: int
        """

        if not self.reader or not self.retriever:
            raise Exception("Finder needs to have a reader and retriever for the evalutaion.")

        counts = defaultdict(float)  # type: Dict[str, float]
        finder_start_time = time.time()

        # extract all questions for evaluation
        filters = {"origin": [label_origin]}
        questions = self.retriever.document_store.get_all_labels_aggregated(index=label_index, filters=filters)
        number_of_questions = len(questions)

        # retrieve documents
        retriever_start_time = time.time()
        questions_with_docs = self._retrieve_docs(questions, top_k=top_k_retriever, doc_index=doc_index)
        retriever_total_time = time.time() - retriever_start_time

        questions_with_correct_doc, summed_avg_precision_retriever = calculate_average_precision(questions_with_docs)
        correct_retrievals = len(questions_with_correct_doc)

        # extract answers
        previous_return_no_answers = self.reader.return_no_answers
        self.reader.return_no_answers = True
        reader_start_time = time.time()
        predictions = self.reader.predict_batch(questions_with_correct_doc,
                                                top_k_per_question=top_k_reader, batch_size=batch_size)
        reader_total_time = time.time() - reader_start_time

        for pred in predictions:
            counts = eval_counts_reader_batch(pred, counts)

        finder_total_time = time.time() - finder_start_time

        results = calculate_reader_metrics(counts, correct_retrievals)
        results["retriever_recall"] = correct_retrievals / number_of_questions
        results["retriever_map"] = summed_avg_precision_retriever / number_of_questions
        results["total_retrieve_time"] = retriever_total_time
        results["avg_retrieve_time"] = retriever_total_time / number_of_questions
        results["total_reader_time"] = reader_total_time
        results["avg_reader_time"] = reader_total_time / correct_retrievals
        results["total_finder_time"] = finder_total_time

        logger.info((f"{counts['correct_readings_topk']} out of {number_of_questions} questions were correctly "
                     f"answered ({(counts['correct_readings_topk'] / number_of_questions):.2%})."))
        logger.info(f"{number_of_questions - correct_retrievals} questions could not be answered due to the retriever.")
        logger.info(f"{correct_retrievals - counts['correct_readings_topk']} questions could not be answered due to the reader.")

        return results


    def _retrieve_docs(self, questions: List[MultiLabel], top_k: int, doc_index: str):
        # Retrieves documents for a list of Labels (= questions)
        questions_with_docs = []

        for question in questions:
            question_string = question.question
            retrieved_docs = self.retriever.retrieve(question_string, top_k=top_k, index=doc_index)  # type: ignore
            questions_with_docs.append({
                "question": question,
                "docs": retrieved_docs
            })

        return questions_with_docs


    @staticmethod
    def print_eval_results(finder_eval_results: Dict):
        print("\n___Retriever Metrics in Finder___")
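The net effect of the refactor above is that the hand-rolled per-question counters are replaced by a counts defaultdict that eval_counts_reader updates. A hedged sketch of that loop shape; here `reader` and `questions_with_correct_doc` are assumed to exist, set up as inside Finder.eval() and calculate_average_precision().

# Sketch only; `reader` and `questions_with_correct_doc` are assumed, not defined here.
from collections import defaultdict
from haystack.eval import eval_counts_reader, calculate_reader_metrics

counts = defaultdict(float)
for item in questions_with_correct_doc:
    predicted = reader.predict(item["question"].question, item["docs"], top_k=5)
    counts = eval_counts_reader(item["question"], predicted, counts)

metrics = calculate_reader_metrics(counts, correct_retrievals=len(questions_with_correct_doc))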
@@ -379,9 +383,10 @@ class Finder:
        print(f"Reader Top-1 F1 (has answer) : {finder_eval_results['reader_top1_f1_has_answer']:.3f}")
        print(f"Reader Top-k F1 : {finder_eval_results['reader_topk_f1']:.3f}")
        print(f"Reader Top-k F1 (has answer) : {finder_eval_results['reader_topk_f1_has_answer']:.3f}")
        print("No Answer")
        print(f"Reader Top-1 no-answer accuracy : {finder_eval_results['reader_top1_no_answer_accuracy']:.3f}")
        print(f"Reader Top-k no-answer accuracy : {finder_eval_results['reader_topk_no_answer_accuracy']:.3f}")
        if finder_eval_results['reader_top1_no_answer_accuracy']:
            print("No Answer")
            print(f"Reader Top-1 no-answer accuracy : {finder_eval_results['reader_top1_no_answer_accuracy']:.3f}")
            print(f"Reader Top-k no-answer accuracy : {finder_eval_results['reader_topk_no_answer_accuracy']:.3f}")

        # Time measurements
        print("\n___Time Measurements___")
@@ -391,3 +396,35 @@ class Finder:
        print(f"Avg read time per question : {finder_eval_results['avg_reader_time']:.3f}")
        print(f"Total Finder time : {finder_eval_results['total_finder_time']:.3f}")

    @staticmethod
    def calc_eval_results(eval_counts: Dict):
        eval_results = {}
        number_of_questions = eval_counts["number_of_questions"]
        correct_retrievals = eval_counts["correct_retrievals"]
        number_of_has_answer = eval_counts["number_of_has_answer"]
        number_of_no_answer = eval_counts["number_of_no_answer"]

        eval_results["retriever_recall"] = eval_counts["correct_retrievals"] / number_of_questions
        eval_results["retriever_map"] = eval_counts["summed_avg_precision_retriever"] / number_of_questions

        eval_results["reader_top1_accuracy"] = eval_counts["correct_readings_top1"] / correct_retrievals
        eval_results["reader_top1_accuracy_has_answer"] = eval_counts["correct_readings_top1_has_answer"] / number_of_has_answer
        eval_results["reader_topk_accuracy"] = eval_counts["correct_readings_topk"] / correct_retrievals
        eval_results["reader_topk_accuracy_has_answer"] = eval_counts["correct_readings_topk_has_answer"] / number_of_has_answer
        eval_results["reader_top1_em"] = eval_counts["exact_matches_top1"] / correct_retrievals
        eval_results["reader_top1_em_has_answer"] = eval_counts["exact_matches_top1_has_answer"] / number_of_has_answer
        eval_results["reader_topk_em"] = eval_counts["exact_matches_topk"] / correct_retrievals
        eval_results["reader_topk_em_has_answer"] = eval_counts["exact_matches_topk_has_answer"] / number_of_has_answer
        eval_results["reader_top1_f1"] = eval_counts["summed_f1_top1"] / correct_retrievals
        eval_results["reader_top1_f1_has_answer"] = eval_counts["summed_f1_top1_has_answer"] / number_of_has_answer
        eval_results["reader_topk_f1"] = eval_counts["summed_f1_topk"] / correct_retrievals
        eval_results["reader_topk_f1_has_answer"] = eval_counts["summed_f1_topk_has_answer"] / number_of_has_answer
        if number_of_no_answer:
            eval_results["reader_top1_no_answer_accuracy"] = eval_counts["correct_no_answers_top1"] / number_of_no_answer
            eval_results["reader_topk_no_answer_accuracy"] = eval_counts["correct_no_answers_topk"] / number_of_no_answer
        else:
            eval_results["reader_top1_no_answer_accuracy"] = None
            eval_results["reader_topk_no_answer_accuracy"] = None

        return eval_results

@@ -7,11 +7,17 @@ from haystack.database.base import Document


class BaseReader(ABC):
    return_no_answers: bool

    @abstractmethod
    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
        pass

    @abstractmethod
    def predict_batch(self, question_doc_list: List[dict], top_k_per_question: Optional[int] = None,
                      batch_size: Optional[int] = None):
        pass

    @staticmethod
    def _calc_no_answer(no_ans_gaps: Sequence[float], best_score_answer: float):
        # "no answer" scores and positive answers scores are difficult to compare, because

@@ -224,6 +224,65 @@ class FARMReader(BaseReader):
        self.inferencer.model.save(directory)
        self.inferencer.processor.save(directory)

    def predict_batch(self, question_doc_list: List[dict], top_k_per_question: int = None, batch_size: int = None):
        """
        Use loaded QA model to find answers for a list of questions in each question's supplied list of Document.

        Returns list of dictionaries containing answers sorted by (desc.) probability

        :param question_doc_list: List of dictionaries containing questions with their retrieved documents
        :param top_k_per_question: the maximum number of answers to return for each question
        :param batch_size: Number of samples the model receives in one batch for inference
        :return: List of dictionaries containing question and answers
        """

        # convert input to FARM format
        inputs = []
        number_of_docs = []
        labels = []

        # build input objects for inference_from_objects
        for question_with_docs in question_doc_list:
            documents = question_with_docs["docs"]
            question = question_with_docs["question"]
            labels.append(question)
            number_of_docs.append(len(documents))

            for doc in documents:
                cur = QAInput(doc_text=doc.text,
                              questions=Question(text=question.question,
                                                 uid=doc.id))
                inputs.append(cur)

        self.inferencer.batch_size = batch_size
        # make predictions on all document-question pairs
        predictions = self.inferencer.inference_from_objects(
            objects=inputs, return_json=False, multiprocessing_chunksize=1
        )

        # group predictions together
        grouped_predictions = []
        left_idx = 0
        right_idx = 0
        for number in number_of_docs:
            right_idx = left_idx + number
            grouped_predictions.append(predictions[left_idx:right_idx])
            left_idx = right_idx

        result = []
        for idx, group in enumerate(grouped_predictions):
            answers, max_no_ans_gap = self._extract_answers_of_predictions(group, top_k_per_question)
            question = group[0]
            cur_label = labels[idx]
            result.append({
                "question": question,
                "no_ans_gap": max_no_ans_gap,
                "answers": answers,
                "label": cur_label
            })

        return result

    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
        """
        Use loaded QA model to find answers for a question in the supplied list of Document.
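For clarity on the shape of question_doc_list, a hedged sketch that mirrors how Finder._retrieve_docs builds it; `retriever`, `reader`, `labels`, and the index name are assumed to exist and are not part of the diff.

# Assumed: `labels` is a List[MultiLabel] (e.g. from get_all_labels_aggregated()),
# and `retriever` / `reader` are configured as usual.
question_doc_list = [
    {"question": label,
     "docs": retriever.retrieve(label.question, top_k=10, index="eval_document")}
    for label in labels
]

predictions = reader.predict_batch(question_doc_list, top_k_per_question=5, batch_size=50)
for pred in predictions:
    # each entry carries the originating label plus the extracted answers
    print(pred["label"].question, [a["answer"] for a in pred["answers"]])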
@@ -263,46 +322,7 @@ class FARMReader(BaseReader):
            objects=inputs, return_json=False, multiprocessing_chunksize=1
        )
        # assemble answers from all the different documents & format them.
        # For the "no answer" option, we collect all no_ans_gaps and decide how likely
        # a no answer is based on all no_ans_gaps values across all documents
        answers = []
        no_ans_gaps = []
        best_score_answer = 0
        for pred in predictions:
            answers_per_document = []
            no_ans_gaps.append(pred.no_answer_gap)
            for ans in pred.prediction:
                # skip "no answers" here
                if self._check_no_answer(ans):
                    pass
                else:
                    cur = {"answer": ans.answer,
                           "score": ans.score,
                           # just a pseudo prob for now
                           "probability": float(expit(np.asarray([ans.score]) / 8)),  # type: ignore
                           "context": ans.context_window,
                           "offset_start": ans.offset_answer_start - ans.offset_context_window_start,
                           "offset_end": ans.offset_answer_end - ans.offset_context_window_start,
                           "offset_start_in_doc": ans.offset_answer_start,
                           "offset_end_in_doc": ans.offset_answer_end,
                           "document_id": pred.id}
                    answers_per_document.append(cur)

                    if ans.score > best_score_answer:
                        best_score_answer = ans.score
            # only take n best candidates. Answers coming back from FARM are sorted with decreasing relevance.
            answers += answers_per_document[:self.top_k_per_candidate]

        # Calculate the score for predicting "no answer", relative to our best positive answer score
        no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps,best_score_answer)
        if self.return_no_answers:
            answers.append(no_ans_prediction)

        # sort answers by their `probability` and select top-k
        answers = sorted(
            answers, key=lambda k: k["probability"], reverse=True
        )
        answers = answers[:top_k]
        answers, max_no_ans_gap = self._extract_answers_of_predictions(predictions, top_k)
        result = {"question": question,
                  "no_ans_gap": max_no_ans_gap,
                  "answers": answers}
@@ -433,6 +453,56 @@ class FARMReader(BaseReader):
        }
        return results

    def _extract_answers_of_predictions(self, predictions: List[QAPred], top_k: Optional[int] = None):
        # Assemble answers from all the different documents and format them.
        # For the 'no answer' option, we collect all no_ans_gaps and decide how likely
        # a no answer is based on all no_ans_gaps values across all documents
        answers = []
        no_ans_gaps = []
        best_score_answer = 0

        for pred in predictions:
            answers_per_document = []
            no_ans_gaps.append(pred.no_answer_gap)
            for ans in pred.prediction:
                # skip 'no answers' here
                if self._check_no_answer(ans):
                    pass
                else:
                    cur = {
                        "answer": ans.answer,
                        "score": ans.score,
                        # just a pseudo prob for now
                        "probability": self._get_pseudo_prob(ans.score),
                        "context": ans.context_window,
                        "offset_start": ans.offset_answer_start - ans.offset_context_window_start,
                        "offset_end": ans.offset_answer_end - ans.offset_context_window_start,
                        "offset_start_in_doc": ans.offset_answer_start,
                        "offset_end_in_doc": ans.offset_answer_end,
                        "document_id": pred.id
                    }
                    answers_per_document.append(cur)

                    if ans.score > best_score_answer:
                        best_score_answer = ans.score

            # Only take n best candidates. Answers coming back from FARM are sorted with decreasing relevance
            answers += answers_per_document[:self.top_k_per_candidate]

        # calculate the score for predicting 'no answer', relative to our best positive answer score
        no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps, best_score_answer)
        if self.return_no_answers:
            answers.append(no_ans_prediction)

        # sort answers by score and select top-k
        answers = sorted(answers, key=lambda k: k["score"], reverse=True)
        answers = answers[:top_k]

        return answers, max_no_ans_gap

    @staticmethod
    def _get_pseudo_prob(score: float):
        return float(expit(np.asarray(score) / 8))

    @staticmethod
    def _check_no_answer(c: QACandidate):
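The new _get_pseudo_prob helper simply squashes the raw FARM score with a logistic function; two illustrative values:

# expit(score / 8) maps the unbounded FARM score into (0, 1), e.g.:
from scipy.special import expit
print(float(expit(8 / 8)))    # score  8 -> ~0.73
print(float(expit(-8 / 8)))   # score -8 -> ~0.27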
@@ -445,7 +515,6 @@ class FARMReader(BaseReader):
        else:
            return False


    def predict_on_texts(self, question: str, texts: List[str], top_k: Optional[int] = None):
        documents = []
        for text in texts:

@@ -51,7 +51,7 @@ class TransformersReader(BaseReader):
        self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
        self.context_window_size = context_window_size
        self.top_k_per_candidate = top_k_per_candidate
        self.no_answer = no_answer
        self.return_no_answers = no_answer

        # TODO context_window_size behaviour different from behavior in FARMReader

@@ -87,7 +87,7 @@ class TransformersReader(BaseReader):
        best_overall_score = 0
        for doc in documents:
            query = {"context": doc.text, "question": question}
            predictions = self.model(query, topk=self.top_k_per_candidate, handle_impossible_answer=self.no_answer)
            predictions = self.model(query, topk=self.top_k_per_candidate, handle_impossible_answer=self.return_no_answers)
            # for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
            if type(predictions) == dict:
                predictions = [predictions]
@@ -124,7 +124,7 @@ class TransformersReader(BaseReader):
        # Calculate the score for predicting "no answer", relative to our best positive answer score
        no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps, best_overall_score)

        if self.no_answer:
        if self.return_no_answers:
            answers.append(no_ans_prediction)
        # sort answers by their `probability` and select top-k
        answers = sorted(
@@ -136,3 +136,8 @@ class TransformersReader(BaseReader):
                   "answers": answers}

        return results

    def predict_batch(self, question_doc_list: List[dict], top_k_per_question: Optional[int] = None,
                      batch_size: Optional[int] = None):

        raise NotImplementedError("Batch prediction not yet available in TransformersReader.")

@@ -1,6 +1,7 @@
import pytest
from haystack.database.base import BaseDocumentStore
from haystack.retriever.sparse import ElasticsearchRetriever
from haystack.finder import Finder


def test_add_eval_data(document_store):
@@ -78,3 +79,40 @@ def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain):
    # clean up
    document_store.delete_all_documents(index="test_eval_document")
    document_store.delete_all_documents(index="test_feedback")


@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_eval_finder(document_store: BaseDocumentStore, reader):
    retriever = ElasticsearchRetriever(document_store=document_store)
    finder = Finder(reader=reader, retriever=retriever)

    # add eval data (SQUAD format)
    document_store.delete_all_documents(index="test_eval_document")
    document_store.delete_all_documents(index="test_feedback")
    document_store.add_eval_data(filename="samples/squad/tiny.json", doc_index="test_eval_document", label_index="test_feedback")
    assert document_store.get_document_count(index="test_eval_document") == 2

    # eval finder
    results = finder.eval(label_index="test_feedback", doc_index="test_eval_document", top_k_retriever=1, top_k_reader=5)
    assert results["retriever_recall"] == 1.0
    assert results["retriever_map"] == 1.0
    assert abs(results["reader_topk_f1"] - 0.66666) < 0.001
    assert abs(results["reader_topk_em"] - 0.5) < 0.001
    assert abs(results["reader_topk_accuracy"] - 1) < 0.001
    assert results["reader_top1_f1"] <= results["reader_topk_f1"]
    assert results["reader_top1_em"] <= results["reader_topk_em"]
    assert results["reader_top1_accuracy"] <= results["reader_topk_accuracy"]

    # batch eval finder
    results_batch = finder.eval_batch(label_index="test_feedback", doc_index="test_eval_document", top_k_retriever=1,
                                      top_k_reader=5)
    assert results_batch["retriever_recall"] == 1.0
    assert results_batch["retriever_map"] == 1.0
    assert results_batch["reader_top1_f1"] == results["reader_top1_f1"]
    assert results_batch["reader_top1_em"] == results["reader_top1_em"]
    assert results_batch["reader_topk_accuracy"] == results["reader_topk_accuracy"]

    # clean up
    document_store.delete_all_documents(index="test_eval_document")
    document_store.delete_all_documents(index="test_feedback")