Fix retriever evaluation metrics (#547)

* Add mean reciprocal rank and fix mean average precision

* Add mrr metric to docstring

* Fix mypy error
bogdankostic 2020-11-05 13:34:47 +01:00 committed by GitHub
parent 53be92c155
commit ffaa0249f7
4 changed files with 83 additions and 30 deletions
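
For reference, the two metrics this PR adds and fixes behave as follows for a single question; the snippet is a minimal standalone sketch with made-up toy data (not part of the diff). MRR and MAP are simply the means of these per-question values over all evaluated questions.

# Toy illustration of reciprocal rank and average precision for one question.
# Gold document ids and the retriever's ranked result list (made-up example data).
relevant = {"d1", "d7"}
retrieved = ["d3", "d1", "d5", "d7"]  # ranks 1..4

# Reciprocal rank: 1 / rank of the first relevant document.
reciprocal_rank = 0.0
for rank, doc_id in enumerate(retrieved, start=1):
    if doc_id in relevant:
        reciprocal_rank = 1 / rank
        break

# Average precision: precision at each relevant hit, normalised by the number of gold documents.
hits = 0
average_precision = 0.0
for rank, doc_id in enumerate(retrieved, start=1):
    if doc_id in relevant:
        hits += 1
        average_precision += (1 / len(relevant)) * (hits / rank)

print(reciprocal_rank)    # 0.5 -> first relevant document sits at rank 2
print(average_precision)  # 0.5 -> (1/2) * (1/2) + (1/2) * (2/4)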


@@ -33,22 +33,33 @@ def calculate_reader_metrics(metric_counts: Dict[str, float], correct_retrievals
     return metrics
 
 
-def calculate_average_precision(questions_with_docs: List[dict]):
+def calculate_average_precision_and_reciprocal_rank(questions_with_docs: List[dict]):
     questions_with_correct_doc = []
     summed_avg_precision_retriever = 0.0
+    summed_reciprocal_rank_retriever = 0.0
 
     for question in questions_with_docs:
+        number_relevant_docs = len(set(question["question"].multiple_document_ids))
+        found_relevant_doc = False
+        relevant_docs_found = 0
         for doc_idx, doc in enumerate(question["docs"]):
             # check if correct doc among retrieved docs
             if doc.id in question["question"].multiple_document_ids:
-                summed_avg_precision_retriever += 1 / (doc_idx + 1)
-                questions_with_correct_doc.append({
-                    "question": question["question"],
-                    "docs": question["docs"]
-                })
-                break
+                if not found_relevant_doc:
+                    summed_reciprocal_rank_retriever += 1 / (doc_idx + 1)
+                relevant_docs_found += 1
+                found_relevant_doc = True
+                summed_avg_precision_retriever += (1 / number_relevant_docs) * (relevant_docs_found / (doc_idx + 1))
+                if relevant_docs_found == number_relevant_docs:
+                    break
+        if found_relevant_doc:
+            questions_with_correct_doc.append({
+                "question": question["question"],
+                "docs": question["docs"]
+            })
 
-    return questions_with_correct_doc, summed_avg_precision_retriever
+    return questions_with_correct_doc, summed_avg_precision_retriever, summed_reciprocal_rank_retriever
 
 
 def eval_counts_reader(question: MultiLabel, predicted_answers: Dict[str, Any], metric_counts: Dict[str, float]):
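
A rough usage sketch for the renamed helper above. The SimpleNamespace stand-ins are illustrative only; the real inputs are Haystack MultiLabel and Document objects with the same attributes.

# Illustrative call of the renamed helper (stand-in objects, not real Haystack classes).
from types import SimpleNamespace

from haystack.eval import calculate_average_precision_and_reciprocal_rank

question = SimpleNamespace(multiple_document_ids=["d1", "d7"])    # gold document ids
docs = [SimpleNamespace(id=i) for i in ["d3", "d1", "d5", "d7"]]  # ranked retrieval result

questions_with_correct_doc, summed_ap, summed_rr = \
    calculate_average_precision_and_reciprocal_rank([{"question": question, "docs": docs}])

# With a single question the sums equal the per-question values: AP = 0.5, RR = 0.5.
print(summed_ap, summed_rr)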


@@ -8,8 +8,8 @@ from collections import defaultdict
 from haystack.reader.base import BaseReader
 from haystack.retriever.base import BaseRetriever
 from haystack import MultiLabel
-from haystack.eval import calculate_average_precision, eval_counts_reader_batch, calculate_reader_metrics, \
-    eval_counts_reader
+from haystack.eval import calculate_average_precision_and_reciprocal_rank, eval_counts_reader_batch, \
+    calculate_reader_metrics, eval_counts_reader
 
 logger = logging.getLogger(__name__)
@@ -131,7 +131,9 @@ class Finder:
         Returns a dict containing the following metrics:
             - ``"retriever_recall"``: Proportion of questions for which correct document is among retrieved documents
             - ``"retriever_map"``: Mean of average precision for each question. Rewards retrievers that give relevant
-              documents a higher rank.
+              documents a higher rank. Considers all retrieved relevant documents.
+            - ``"retriever_mrr"``: Mean of reciprocal rank for each question. Rewards retrievers that give relevant
+              documents a higher rank. Only considers the highest ranked relevant document.
             - ``"reader_top1_accuracy"``: Proportion of highest ranked predicted answers that overlap with corresponding correct answer
             - ``"reader_top1_accuracy_has_answer"``: Proportion of highest ranked predicted answers that overlap
               with corresponding correct answer for answerable questions
@@ -193,17 +195,28 @@ class Finder:
             single_retrieve_start = time.time()
             retrieved_docs = self.retriever.retrieve(question_string, top_k=top_k_retriever, index=doc_index)
             retrieve_times.append(time.time() - single_retrieve_start)
+            number_relevant_docs = len(set(question.multiple_document_ids))
 
             # check if correct doc among retrieved docs
+            found_relevant_doc = False
+            relevant_docs_found = 0
             for doc_idx, doc in enumerate(retrieved_docs):
                 if doc.id in question.multiple_document_ids:
-                    counts["correct_retrievals"] += 1
-                    counts["summed_avg_precision_retriever"] += 1 / (doc_idx + 1)
-                    questions_with_docs.append({
-                        "question": question,
-                        "docs": retrieved_docs
-                    })
-                    break
+                    relevant_docs_found += 1
+                    if not found_relevant_doc:
+                        counts["correct_retrievals"] += 1
+                        counts["summed_reciprocal_rank_retriever"] += 1 / (doc_idx + 1)
+                    counts["summed_avg_precision_retriever"] += (1 / number_relevant_docs) \
+                                                                * (relevant_docs_found / (doc_idx + 1))
+                    found_relevant_doc = True
+                    if relevant_docs_found == number_relevant_docs:
+                        break
+            if found_relevant_doc:
+                questions_with_docs.append({
+                    "question": question,
+                    "docs": retrieved_docs
+                })
 
         retriever_total_time = time.time() - retriever_start_time
         counts["number_of_questions"] = q_idx + 1
@@ -270,7 +283,9 @@ class Finder:
         Returns a dict containing the following metrics:
             - ``"retriever_recall"``: Proportion of questions for which correct document is among retrieved documents
             - ``"retriever_map"``: Mean of average precision for each question. Rewards retrievers that give relevant
-              documents a higher rank.
+              documents a higher rank. Considers all retrieved relevant documents.
+            - ``"retriever_mrr"``: Mean of reciprocal rank for each question. Rewards retrievers that give relevant
+              documents a higher rank. Only considers the highest ranked relevant document.
             - ``"reader_top1_accuracy"``: Proportion of highest ranked predicted answers that overlap with corresponding correct answer
             - ``"reader_top1_accuracy_has_answer"``: Proportion of highest ranked predicted answers that overlap
               with corresponding correct answer for answerable questions
@@ -330,7 +345,10 @@ class Finder:
         questions_with_docs = self._retrieve_docs(questions, top_k=top_k_retriever, doc_index=doc_index)
         retriever_total_time = time.time() - retriever_start_time
 
-        questions_with_correct_doc, summed_avg_precision_retriever = calculate_average_precision(questions_with_docs)
+        questions_with_correct_doc, \
+        summed_avg_precision_retriever, \
+        summed_reciprocal_rank_retriever = calculate_average_precision_and_reciprocal_rank(questions_with_docs)
+
         correct_retrievals = len(questions_with_correct_doc)
 
         # extract answers
@@ -349,6 +367,7 @@ class Finder:
         results = calculate_reader_metrics(counts, correct_retrievals)
         results["retriever_recall"] = correct_retrievals / number_of_questions
         results["retriever_map"] = summed_avg_precision_retriever / number_of_questions
+        results["retriever_mrr"] = summed_reciprocal_rank_retriever / number_of_questions
         results["total_retrieve_time"] = retriever_total_time
         results["avg_retrieve_time"] = retriever_total_time / number_of_questions
         results["total_reader_time"] = reader_total_time
@@ -389,6 +408,7 @@ class Finder:
         print("\n___Retriever Metrics in Finder___")
         print(f"Retriever Recall : {finder_eval_results['retriever_recall']:.3f}")
         print(f"Retriever Mean Avg Precision: {finder_eval_results['retriever_map']:.3f}")
+        print(f"Retriever Mean Reciprocal Rank: {finder_eval_results['retriever_mrr']:.3f}")
 
         # Reader is only evaluated with those questions, where the correct document is among the retrieved ones
         print("\n___Reader Metrics in Finder___")
@@ -430,6 +450,7 @@ class Finder:
         eval_results["retriever_recall"] = eval_counts["correct_retrievals"] / number_of_questions
         eval_results["retriever_map"] = eval_counts["summed_avg_precision_retriever"] / number_of_questions
+        eval_results["retriever_mrr"] = eval_counts["summed_reciprocal_rank_retriever"] / number_of_questions
         eval_results["reader_top1_accuracy"] = eval_counts["correct_readings_top1"] / correct_retrievals
         eval_results["reader_top1_accuracy_has_answer"] = eval_counts["correct_readings_top1_has_answer"] / number_of_has_answer


@@ -56,8 +56,10 @@ class BaseRetriever(ABC):
         | Returns a dict containing the following metrics:
 
             - "recall": Proportion of questions for which correct document is among retrieved documents
-            - "mean avg precision": Mean of average precision for each question. Rewards retrievers that give relevant
-              documents a higher rank.
+            - "mrr": Mean of reciprocal rank. Rewards retrievers that give relevant documents a higher rank.
+              Only considers the highest ranked relevant document.
+            - "map": Mean of average precision for each question. Rewards retrievers that give relevant
+              documents a higher rank. Considers all retrieved relevant documents. (only with ``open_domain=False``)
 
         :param label_index: Index/Table in DocumentStore where labeled questions are stored
         :param doc_index: Index/Table in DocumentStore where documents that are used for evaluation are stored
@@ -78,7 +80,8 @@ class BaseRetriever(ABC):
         labels = self.document_store.get_all_labels_aggregated(index=label_index, filters=filters)
 
         correct_retrievals = 0
-        summed_avg_precision = 0
+        summed_avg_precision = 0.0
+        summed_reciprocal_rank = 0.0
 
         # Collect questions and corresponding answers/document_ids in a dict
         question_label_dict = {}
@@ -99,12 +102,18 @@ class BaseRetriever(ABC):
                 if return_preds:
                     predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
+                found_relevant_doc = False
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_answer in gold_answers:
                         if gold_answer in doc.text:
-                            correct_retrievals += 1
-                            summed_avg_precision += 1 / (doc_idx + 1)  # type: ignore
+                            if not found_relevant_doc:
+                                correct_retrievals += 1
+                                summed_reciprocal_rank += 1 / (doc_idx + 1)
+                            found_relevant_doc = True
                             break
+                    # For the metrics in the open-domain case we are only considering the highest ranked relevant doc
+                    if found_relevant_doc:
+                        break
 
         # Option 2: Strict evaluation by document ids that are listed in the labels
         else:
             for question, gold_ids in tqdm(question_label_dict.items()):
@@ -112,28 +121,38 @@ class BaseRetriever(ABC):
                 if return_preds:
                     predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
+                relevant_docs_found = 0
+                found_relevant_doc = False
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_id in gold_ids:
                         if str(doc.id) == gold_id:
-                            correct_retrievals += 1
-                            summed_avg_precision += 1 / (doc_idx + 1)  # type: ignore
+                            if not found_relevant_doc:
+                                correct_retrievals += 1
+                                summed_reciprocal_rank += 1 / (doc_idx + 1)
+                            found_relevant_doc = True
+                            relevant_docs_found += 1
+                            summed_avg_precision += (1 / len(gold_ids)) * (relevant_docs_found / (doc_idx + 1))
                             break
 
         # Metrics
         number_of_questions = len(question_label_dict)
         recall = correct_retrievals / number_of_questions
-        mean_avg_precision = summed_avg_precision / number_of_questions
+        mean_reciprocal_rank = summed_reciprocal_rank / number_of_questions
 
         logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in"
                      f" the top-{top_k} candidate passages selected by the retriever."))
 
         metrics = {
             "recall": recall,
-            "map": mean_avg_precision,
+            "mrr": mean_reciprocal_rank,
             "retrieve_time": self.retrieve_time,
             "n_questions": number_of_questions,
             "top_k": top_k
         }
+
+        if not open_domain:
+            mean_avg_precision = summed_avg_precision / number_of_questions
+            metrics["map"] = mean_avg_precision
 
         if return_preds:
             return {"metrics": metrics, "predictions": predictions}
         else:
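
A hedged sketch of how the changed BaseRetriever.eval output would be consumed. The top_k value and index names are placeholders, and the retriever is assumed to be already initialised and backed by a document store that holds the labels and documents.

from haystack.retriever.base import BaseRetriever

def report_retriever_metrics(retriever: BaseRetriever, open_domain: bool = False) -> None:
    # Placeholder index names; point them at wherever the labels and eval documents are stored.
    metrics = retriever.eval(top_k=5, label_index="label", doc_index="eval_document", open_domain=open_domain)
    print(f"recall: {metrics['recall']:.3f}")
    print(f"mrr:    {metrics['mrr']:.3f}")
    # "map" is only reported for the strict (document-id) evaluation, i.e. open_domain=False,
    # because the open-domain branch stops at the first matching document.
    if "map" in metrics:
        print(f"map:    {metrics['map']:.3f}")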


@@ -75,7 +75,9 @@ def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain,
     # eval retriever
     results = retriever.eval(top_k=1, label_index="test_feedback", doc_index="test_eval_document", open_domain=open_domain)
 
     assert results["recall"] == 1.0
-    assert results["map"] == 1.0
+    assert results["mrr"] == 1.0
+    if not open_domain:
+        assert results["map"] == 1.0
 
     # clean up
     document_store.delete_all_documents(index="test_eval_document")