Add More top_k handling to EvalDocuments (#1133)

* Improve top_k support

* Adjust warning

* Satisfy mypy

* Reinit eval counts if top_k has changed

* Incorporate reviewer feedback
This commit is contained in:
Branden Chan 2021-06-07 12:11:00 +02:00 committed by GitHub
parent c513865566
commit 59e3c55c47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,4 @@
from typing import List, Tuple, Dict, Any
from typing import List, Tuple, Dict, Any, Optional
import logging
from haystack import MultiLabel, Label
@ -18,7 +18,7 @@ class EvalDocuments:
a look at our evaluation tutorial for more info about open vs closed domain eval (
https://haystack.deepset.ai/docs/latest/tutorial5md).
"""
def __init__(self, debug: bool=False, open_domain: bool=True, top_k: int=10, name="EvalDocuments"):
def __init__(self, debug: bool=False, open_domain: bool=True, top_k_eval_documents: int=10, name="EvalDocuments"):
"""
:param open_domain: When True, a document is considered correctly retrieved so long as the answer string can be found within it.
When False, correct retrieval is evaluated based on document_id.
@ -31,8 +31,10 @@ class EvalDocuments:
self.debug = debug
self.log: List = []
self.open_domain = open_domain
self.top_k = top_k
self.top_k_eval_documents = top_k_eval_documents
self.name = name
self.too_few_docs_warning = False
self.top_k_used = 0
def init_counts(self):
self.correct_retrieval_count = 0
@ -47,10 +49,26 @@ class EvalDocuments:
self.reciprocal_rank_sum = 0.0
self.has_answer_reciprocal_rank_sum = 0.0
def run(self, documents, labels: dict, **kwargs):
def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None, **kwargs):
"""Run this node on one sample and its labels"""
self.query_count += 1
retriever_labels = get_label(labels, kwargs["node_id"])
if not top_k_eval_documents:
top_k_eval_documents = self.top_k_eval_documents
if not self.top_k_used:
self.top_k_used = top_k_eval_documents
elif self.top_k_used != top_k_eval_documents:
logger.warning(f"EvalDocuments was last run with top_k_eval_documents={self.top_k_used} but is "
f"being run again with top_k_eval_documents={self.top_k_eval_documents}. "
f"The evaluation counter is being reset from this point so that the evaluation "
f"metrics are interpretable.")
self.init_counts()
if len(documents) < top_k_eval_documents and not self.too_few_docs_warning:
logger.warning(f"EvalDocuments is being provided less candidate documents than top_k_eval_documents "
f"(currently set to {top_k_eval_documents}).")
self.too_few_docs_warning = True
# TODO retriever_labels is currently a Multilabel object but should eventually be a RetrieverLabel object
# If this sample is impossible to answer and expects a no_answer response
@ -67,7 +85,7 @@ class EvalDocuments:
# If there are answer span annotations in the labels
else:
self.has_answer_count += 1
retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents)
retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents, top_k_eval_documents)
self.reciprocal_rank_sum += retrieved_reciprocal_rank
correct_retrieval = True if retrieved_reciprocal_rank > 0 else False
self.has_answer_correct += int(correct_retrieval)
@ -78,6 +96,8 @@ class EvalDocuments:
self.correct_retrieval_count += correct_retrieval
self.recall = self.correct_retrieval_count / self.query_count
self.mean_reciprocal_rank = self.reciprocal_rank_sum / self.query_count
self.top_k_used = top_k_eval_documents
if self.debug:
self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs})
@ -86,15 +106,15 @@ class EvalDocuments:
def is_correctly_retrieved(self, retriever_labels, predictions):
return self.reciprocal_rank_retrieved(retriever_labels, predictions) > 0
def reciprocal_rank_retrieved(self, retriever_labels, predictions):
def reciprocal_rank_retrieved(self, retriever_labels, predictions, top_k_eval_documents):
if self.open_domain:
for label in retriever_labels.multiple_answers:
for rank, p in enumerate(predictions[:self.top_k]):
for rank, p in enumerate(predictions[:top_k_eval_documents]):
if label.lower() in p.text.lower():
return 1/(rank+1)
return False
else:
prediction_ids = [p.id for p in predictions[:self.top_k]]
prediction_ids = [p.id for p in predictions[:top_k_eval_documents]]
label_ids = retriever_labels.multiple_document_ids
for rank, p in enumerate(prediction_ids):
if p in label_ids:
@ -107,15 +127,15 @@ class EvalDocuments:
print("-----------------")
if self.no_answer_count:
print(
f"has_answer recall@{self.top_k}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})")
f"has_answer recall@{self.top_k_used}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})")
print(
f"no_answer recall@{self.top_k}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
f"no_answer recall@{self.top_k_used}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
print(
f"has_answer mean_reciprocal_rank@{self.top_k}: {self.has_answer_mean_reciprocal_rank:.4f}")
f"has_answer mean_reciprocal_rank@{self.top_k_used}: {self.has_answer_mean_reciprocal_rank:.4f}")
print(
f"no_answer mean_reciprocal_rank@{self.top_k}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)")
print(f"recall@{self.top_k}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})")
print(f"mean_reciprocal_rank@{self.top_k}: {self.mean_reciprocal_rank:.4f}")
f"no_answer mean_reciprocal_rank@{self.top_k_used}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)")
print(f"recall@{self.top_k_used}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})")
print(f"mean_reciprocal_rank@{self.top_k_used}: {self.mean_reciprocal_rank:.4f}")
class EvalAnswers: