From 59e3c55c470bbd3e7b22fd6f509319a6b1eb41b5 Mon Sep 17 00:00:00 2001
From: Branden Chan <33759007+brandenchan@users.noreply.github.com>
Date: Mon, 7 Jun 2021 12:11:00 +0200
Subject: [PATCH] Add More top_k handling to EvalDocuments (#1133)

* Improve top_k support

* Adjust warning

* Satisfy mypy

* Reinit eval counts if top_k has changed

* Incorporate reviewer feedback
---
 haystack/eval.py | 48 ++++++++++++++++++++++++++++++++++--------------
 1 file changed, 34 insertions(+), 14 deletions(-)

diff --git a/haystack/eval.py b/haystack/eval.py
index 6d5f3121a..42dc65684 100644
--- a/haystack/eval.py
+++ b/haystack/eval.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Dict, Any
+from typing import List, Tuple, Dict, Any, Optional
 import logging
 
 from haystack import MultiLabel, Label
@@ -18,7 +18,7 @@ class EvalDocuments:
     a look at our evaluation tutorial for more info about open vs closed domain eval (
     https://haystack.deepset.ai/docs/latest/tutorial5md).
     """
-    def __init__(self, debug: bool=False, open_domain: bool=True, top_k: int=10, name="EvalDocuments"):
+    def __init__(self, debug: bool=False, open_domain: bool=True, top_k_eval_documents: int=10, name="EvalDocuments"):
         """
         :param open_domain: When True, a document is considered correctly retrieved so long as the answer string
                 can be found within it. When False, correct retrieval is evaluated based on document_id.
@@ -31,8 +31,10 @@ class EvalDocuments:
         self.debug = debug
         self.log: List = []
         self.open_domain = open_domain
-        self.top_k = top_k
+        self.top_k_eval_documents = top_k_eval_documents
         self.name = name
+        self.too_few_docs_warning = False
+        self.top_k_used = 0
 
     def init_counts(self):
         self.correct_retrieval_count = 0
@@ -47,10 +49,26 @@ class EvalDocuments:
         self.reciprocal_rank_sum = 0.0
         self.has_answer_reciprocal_rank_sum = 0.0
 
-    def run(self, documents, labels: dict, **kwargs):
+    def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None, **kwargs):
         """Run this node on one sample and its labels"""
         self.query_count += 1
         retriever_labels = get_label(labels, kwargs["node_id"])
+        if not top_k_eval_documents:
+            top_k_eval_documents = self.top_k_eval_documents
+
+        if not self.top_k_used:
+            self.top_k_used = top_k_eval_documents
+        elif self.top_k_used != top_k_eval_documents:
+            logger.warning(f"EvalDocuments was last run with top_k_eval_documents={self.top_k_used} but is "
+                           f"being run again with top_k_eval_documents={self.top_k_eval_documents}. "
+                           f"The evaluation counter is being reset from this point so that the evaluation "
+                           f"metrics are interpretable.")
+            self.init_counts()
+
+        if len(documents) < top_k_eval_documents and not self.too_few_docs_warning:
+            logger.warning(f"EvalDocuments is being provided less candidate documents than top_k_eval_documents "
+                           f"(currently set to {top_k_eval_documents}).")
+            self.too_few_docs_warning = True
 
         # TODO retriever_labels is currently a Multilabel object but should eventually be a RetrieverLabel object
         # If this sample is impossible to answer and expects a no_answer response
@@ -67,7 +85,7 @@ class EvalDocuments:
         # If there are answer span annotations in the labels
         else:
             self.has_answer_count += 1
-            retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents)
+            retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents, top_k_eval_documents)
             self.reciprocal_rank_sum += retrieved_reciprocal_rank
             correct_retrieval = True if retrieved_reciprocal_rank > 0 else False
             self.has_answer_correct += int(correct_retrieval)
@@ -78,6 +96,8 @@ class EvalDocuments:
         self.correct_retrieval_count += correct_retrieval
         self.recall = self.correct_retrieval_count / self.query_count
         self.mean_reciprocal_rank = self.reciprocal_rank_sum / self.query_count
+
+        self.top_k_used = top_k_eval_documents
 
         if self.debug:
             self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs})
@@ -86,15 +106,15 @@ class EvalDocuments:
     def is_correctly_retrieved(self, retriever_labels, predictions):
         return self.reciprocal_rank_retrieved(retriever_labels, predictions) > 0
 
-    def reciprocal_rank_retrieved(self, retriever_labels, predictions):
+    def reciprocal_rank_retrieved(self, retriever_labels, predictions, top_k_eval_documents):
         if self.open_domain:
             for label in retriever_labels.multiple_answers:
-                for rank, p in enumerate(predictions[:self.top_k]):
+                for rank, p in enumerate(predictions[:top_k_eval_documents]):
                     if label.lower() in p.text.lower():
                         return 1/(rank+1)
             return False
         else:
-            prediction_ids = [p.id for p in predictions[:self.top_k]]
+            prediction_ids = [p.id for p in predictions[:top_k_eval_documents]]
             label_ids = retriever_labels.multiple_document_ids
             for rank, p in enumerate(prediction_ids):
                 if p in label_ids:
@@ -107,15 +127,15 @@ class EvalDocuments:
         print("-----------------")
         if self.no_answer_count:
             print(
-                f"has_answer recall@{self.top_k}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})")
+                f"has_answer recall@{self.top_k_used}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})")
             print(
-                f"no_answer recall@{self.top_k}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
+                f"no_answer recall@{self.top_k_used}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
             print(
-                f"has_answer mean_reciprocal_rank@{self.top_k}: {self.has_answer_mean_reciprocal_rank:.4f}")
+                f"has_answer mean_reciprocal_rank@{self.top_k_used}: {self.has_answer_mean_reciprocal_rank:.4f}")
             print(
-                f"no_answer mean_reciprocal_rank@{self.top_k}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)")
-        print(f"recall@{self.top_k}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})")
-        print(f"mean_reciprocal_rank@{self.top_k}: {self.mean_reciprocal_rank:.4f}")
+                f"no_answer mean_reciprocal_rank@{self.top_k_used}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)")
+        print(f"recall@{self.top_k_used}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})")
+        print(f"mean_reciprocal_rank@{self.top_k_used}: {self.mean_reciprocal_rank:.4f}")
 
 
 class EvalAnswers: