Add More top_k handling to EvalDocuments (#1133)

* Improve top_k support

* Adjust warning

* Satisfy mypy

* Reinit eval counts if top_k has changed

* Incorporate reviewer feedback
This commit is contained in:
Branden Chan 2021-06-07 12:11:00 +02:00 committed by GitHub
parent c513865566
commit 59e3c55c47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,4 @@
from typing import List, Tuple, Dict, Any from typing import List, Tuple, Dict, Any, Optional
import logging import logging
from haystack import MultiLabel, Label from haystack import MultiLabel, Label
@ -18,7 +18,7 @@ class EvalDocuments:
a look at our evaluation tutorial for more info about open vs closed domain eval ( a look at our evaluation tutorial for more info about open vs closed domain eval (
https://haystack.deepset.ai/docs/latest/tutorial5md). https://haystack.deepset.ai/docs/latest/tutorial5md).
""" """
def __init__(self, debug: bool=False, open_domain: bool=True, top_k: int=10, name="EvalDocuments"): def __init__(self, debug: bool=False, open_domain: bool=True, top_k_eval_documents: int=10, name="EvalDocuments"):
""" """
:param open_domain: When True, a document is considered correctly retrieved so long as the answer string can be found within it. :param open_domain: When True, a document is considered correctly retrieved so long as the answer string can be found within it.
When False, correct retrieval is evaluated based on document_id. When False, correct retrieval is evaluated based on document_id.
@ -31,8 +31,10 @@ class EvalDocuments:
self.debug = debug self.debug = debug
self.log: List = [] self.log: List = []
self.open_domain = open_domain self.open_domain = open_domain
self.top_k = top_k self.top_k_eval_documents = top_k_eval_documents
self.name = name self.name = name
self.too_few_docs_warning = False
self.top_k_used = 0
def init_counts(self): def init_counts(self):
self.correct_retrieval_count = 0 self.correct_retrieval_count = 0
@ -47,10 +49,26 @@ class EvalDocuments:
self.reciprocal_rank_sum = 0.0 self.reciprocal_rank_sum = 0.0
self.has_answer_reciprocal_rank_sum = 0.0 self.has_answer_reciprocal_rank_sum = 0.0
def run(self, documents, labels: dict, **kwargs): def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None, **kwargs):
"""Run this node on one sample and its labels""" """Run this node on one sample and its labels"""
self.query_count += 1 self.query_count += 1
retriever_labels = get_label(labels, kwargs["node_id"]) retriever_labels = get_label(labels, kwargs["node_id"])
if not top_k_eval_documents:
top_k_eval_documents = self.top_k_eval_documents
if not self.top_k_used:
self.top_k_used = top_k_eval_documents
elif self.top_k_used != top_k_eval_documents:
logger.warning(f"EvalDocuments was last run with top_k_eval_documents={self.top_k_used} but is "
f"being run again with top_k_eval_documents={self.top_k_eval_documents}. "
f"The evaluation counter is being reset from this point so that the evaluation "
f"metrics are interpretable.")
self.init_counts()
if len(documents) < top_k_eval_documents and not self.too_few_docs_warning:
logger.warning(f"EvalDocuments is being provided less candidate documents than top_k_eval_documents "
f"(currently set to {top_k_eval_documents}).")
self.too_few_docs_warning = True
# TODO retriever_labels is currently a Multilabel object but should eventually be a RetrieverLabel object # TODO retriever_labels is currently a Multilabel object but should eventually be a RetrieverLabel object
# If this sample is impossible to answer and expects a no_answer response # If this sample is impossible to answer and expects a no_answer response
@ -67,7 +85,7 @@ class EvalDocuments:
# If there are answer span annotations in the labels # If there are answer span annotations in the labels
else: else:
self.has_answer_count += 1 self.has_answer_count += 1
retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents) retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents, top_k_eval_documents)
self.reciprocal_rank_sum += retrieved_reciprocal_rank self.reciprocal_rank_sum += retrieved_reciprocal_rank
correct_retrieval = True if retrieved_reciprocal_rank > 0 else False correct_retrieval = True if retrieved_reciprocal_rank > 0 else False
self.has_answer_correct += int(correct_retrieval) self.has_answer_correct += int(correct_retrieval)
@ -79,6 +97,8 @@ class EvalDocuments:
self.recall = self.correct_retrieval_count / self.query_count self.recall = self.correct_retrieval_count / self.query_count
self.mean_reciprocal_rank = self.reciprocal_rank_sum / self.query_count self.mean_reciprocal_rank = self.reciprocal_rank_sum / self.query_count
self.top_k_used = top_k_eval_documents
if self.debug: if self.debug:
self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs}) self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs})
return {"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs}, "output_1" return {"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs}, "output_1"
@ -86,15 +106,15 @@ class EvalDocuments:
def is_correctly_retrieved(self, retriever_labels, predictions): def is_correctly_retrieved(self, retriever_labels, predictions):
return self.reciprocal_rank_retrieved(retriever_labels, predictions) > 0 return self.reciprocal_rank_retrieved(retriever_labels, predictions) > 0
def reciprocal_rank_retrieved(self, retriever_labels, predictions): def reciprocal_rank_retrieved(self, retriever_labels, predictions, top_k_eval_documents):
if self.open_domain: if self.open_domain:
for label in retriever_labels.multiple_answers: for label in retriever_labels.multiple_answers:
for rank, p in enumerate(predictions[:self.top_k]): for rank, p in enumerate(predictions[:top_k_eval_documents]):
if label.lower() in p.text.lower(): if label.lower() in p.text.lower():
return 1/(rank+1) return 1/(rank+1)
return False return False
else: else:
prediction_ids = [p.id for p in predictions[:self.top_k]] prediction_ids = [p.id for p in predictions[:top_k_eval_documents]]
label_ids = retriever_labels.multiple_document_ids label_ids = retriever_labels.multiple_document_ids
for rank, p in enumerate(prediction_ids): for rank, p in enumerate(prediction_ids):
if p in label_ids: if p in label_ids:
@ -107,15 +127,15 @@ class EvalDocuments:
print("-----------------") print("-----------------")
if self.no_answer_count: if self.no_answer_count:
print( print(
f"has_answer recall@{self.top_k}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})") f"has_answer recall@{self.top_k_used}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})")
print( print(
f"no_answer recall@{self.top_k}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)") f"no_answer recall@{self.top_k_used}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
print( print(
f"has_answer mean_reciprocal_rank@{self.top_k}: {self.has_answer_mean_reciprocal_rank:.4f}") f"has_answer mean_reciprocal_rank@{self.top_k_used}: {self.has_answer_mean_reciprocal_rank:.4f}")
print( print(
f"no_answer mean_reciprocal_rank@{self.top_k}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)") f"no_answer mean_reciprocal_rank@{self.top_k_used}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)")
print(f"recall@{self.top_k}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})") print(f"recall@{self.top_k_used}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})")
print(f"mean_reciprocal_rank@{self.top_k}: {self.mean_reciprocal_rank:.4f}") print(f"mean_reciprocal_rank@{self.top_k_used}: {self.mean_reciprocal_rank:.4f}")
class EvalAnswers: class EvalAnswers: