Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-25 16:15:35 +00:00)
Add More top_k handling to EvalDocuments (#1133)
* Improve top_k support
* Adjust warning
* Satisfy mypy
* Reinit eval counts if top_k has changed
* Incorporate reviewer feedback
This commit is contained in:
parent c513865566
commit 59e3c55c47
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Dict, Any
+from typing import List, Tuple, Dict, Any, Optional
 import logging
 
 from haystack import MultiLabel, Label
@@ -18,7 +18,7 @@ class EvalDocuments:
     a look at our evaluation tutorial for more info about open vs closed domain eval (
     https://haystack.deepset.ai/docs/latest/tutorial5md).
     """
-    def __init__(self, debug: bool=False, open_domain: bool=True, top_k: int=10, name="EvalDocuments"):
+    def __init__(self, debug: bool=False, open_domain: bool=True, top_k_eval_documents: int=10, name="EvalDocuments"):
         """
         :param open_domain: When True, a document is considered correctly retrieved so long as the answer string can be found within it.
                             When False, correct retrieval is evaluated based on document_id.
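The open_domain flag in this docstring governs how a retrieved document is judged correct: by answer string containment when True, by document id when False. Below is a small standalone sketch of those two matching modes; it is an illustration only, not Haystack code, and the Doc stand-in with id/text fields is an assumption mirroring the attributes used later in this diff.

# Simplified illustration of open- vs closed-domain matching (not Haystack code).
from dataclasses import dataclass
from typing import List

@dataclass
class Doc:                      # stand-in for a retrieved document with .id and .text
    id: str
    text: str

def is_hit_open_domain(answer: str, doc: Doc) -> bool:
    # open_domain=True: correct if the answer string occurs anywhere in the document text
    return answer.lower() in doc.text.lower()

def is_hit_closed_domain(gold_ids: List[str], doc: Doc) -> bool:
    # open_domain=False: correct only if the document id matches an annotated gold id
    return doc.id in gold_ids

doc = Doc(id="d42", text="Paris is the capital of France.")
print(is_hit_open_domain("paris", doc))          # True
print(is_hit_closed_domain(["d7", "d13"], doc))  # False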
@@ -31,8 +31,10 @@ class EvalDocuments:
         self.debug = debug
         self.log: List = []
         self.open_domain = open_domain
-        self.top_k = top_k
+        self.top_k_eval_documents = top_k_eval_documents
         self.name = name
+        self.too_few_docs_warning = False
+        self.top_k_used = 0
 
     def init_counts(self):
         self.correct_retrieval_count = 0
@@ -47,10 +49,26 @@ class EvalDocuments:
         self.reciprocal_rank_sum = 0.0
         self.has_answer_reciprocal_rank_sum = 0.0
 
-    def run(self, documents, labels: dict, **kwargs):
+    def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None, **kwargs):
         """Run this node on one sample and its labels"""
         self.query_count += 1
         retriever_labels = get_label(labels, kwargs["node_id"])
+        if not top_k_eval_documents:
+            top_k_eval_documents = self.top_k_eval_documents
+
+        if not self.top_k_used:
+            self.top_k_used = top_k_eval_documents
+        elif self.top_k_used != top_k_eval_documents:
+            logger.warning(f"EvalDocuments was last run with top_k_eval_documents={self.top_k_used} but is "
+                           f"being run again with top_k_eval_documents={self.top_k_eval_documents}. "
+                           f"The evaluation counter is being reset from this point so that the evaluation "
+                           f"metrics are interpretable.")
+            self.init_counts()
+
+        if len(documents) < top_k_eval_documents and not self.too_few_docs_warning:
+            logger.warning(f"EvalDocuments is being provided less candidate documents than top_k_eval_documents "
+                           f"(currently set to {top_k_eval_documents}).")
+            self.too_few_docs_warning = True
 
         # TODO retriever_labels is currently a Multilabel object but should eventually be a RetrieverLabel object
         # If this sample is impossible to answer and expects a no_answer response
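The added run() logic falls back to the top_k_eval_documents configured in __init__ when no per-call value is given, and resets the evaluation counters (with a warning) if a different cut-off is used mid-evaluation, so metrics are never averaged over mixed top_k values. A minimal standalone sketch of that pattern, assuming a toy counter-holding class rather than the real node:

# Toy stand-in for the fallback/reset pattern added to EvalDocuments.run() (not the real class).
import logging

logger = logging.getLogger(__name__)

class TopKTracker:
    def __init__(self, top_k_eval_documents: int = 10):
        self.top_k_eval_documents = top_k_eval_documents
        self.top_k_used = 0
        self.init_counts()

    def init_counts(self):
        self.query_count = 0
        self.correct_retrieval_count = 0

    def run(self, top_k_eval_documents=None):
        if not top_k_eval_documents:                       # fall back to the configured default
            top_k_eval_documents = self.top_k_eval_documents
        if not self.top_k_used:                            # first call: remember the cut-off
            self.top_k_used = top_k_eval_documents
        elif self.top_k_used != top_k_eval_documents:      # cut-off changed: warn and start over
            logger.warning("top_k changed from %s to %s; resetting counts",
                           self.top_k_used, top_k_eval_documents)
            self.init_counts()
        self.query_count += 1
        self.top_k_used = top_k_eval_documents

tracker = TopKTracker(top_k_eval_documents=10)
tracker.run()                          # evaluated at the default cut-off of 10
tracker.run(top_k_eval_documents=5)    # warns and resets the counters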
@@ -67,7 +85,7 @@ class EvalDocuments:
         # If there are answer span annotations in the labels
         else:
             self.has_answer_count += 1
-            retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents)
+            retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents, top_k_eval_documents)
             self.reciprocal_rank_sum += retrieved_reciprocal_rank
             correct_retrieval = True if retrieved_reciprocal_rank > 0 else False
             self.has_answer_correct += int(correct_retrieval)
@@ -78,6 +96,8 @@ class EvalDocuments:
         self.correct_retrieval_count += correct_retrieval
         self.recall = self.correct_retrieval_count / self.query_count
         self.mean_reciprocal_rank = self.reciprocal_rank_sum / self.query_count
 
+        self.top_k_used = top_k_eval_documents
+
         if self.debug:
             self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs})
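These lines keep recall and mean reciprocal rank as running averages over all queries seen so far: recall = correct_retrieval_count / query_count and MRR = reciprocal_rank_sum / query_count. A small worked example of how those averages evolve, using arbitrary per-query reciprocal ranks:

# Worked example of the running recall / MRR updates (values are illustrative only).
reciprocal_ranks = [1.0, 0.0, 0.5, 0.25]   # 1/rank of the first relevant document per query, 0.0 if missed

query_count = 0
correct_retrieval_count = 0
reciprocal_rank_sum = 0.0

for rr in reciprocal_ranks:
    query_count += 1
    correct_retrieval_count += int(rr > 0)
    reciprocal_rank_sum += rr
    recall = correct_retrieval_count / query_count
    mean_reciprocal_rank = reciprocal_rank_sum / query_count

print(recall)                 # 0.75   (3 of 4 queries retrieved a relevant document)
print(mean_reciprocal_rank)   # 0.4375 ((1.0 + 0.0 + 0.5 + 0.25) / 4)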
@@ -86,15 +106,15 @@ class EvalDocuments:
     def is_correctly_retrieved(self, retriever_labels, predictions):
         return self.reciprocal_rank_retrieved(retriever_labels, predictions) > 0
 
-    def reciprocal_rank_retrieved(self, retriever_labels, predictions):
+    def reciprocal_rank_retrieved(self, retriever_labels, predictions, top_k_eval_documents):
         if self.open_domain:
             for label in retriever_labels.multiple_answers:
-                for rank, p in enumerate(predictions[:self.top_k]):
+                for rank, p in enumerate(predictions[:top_k_eval_documents]):
                     if label.lower() in p.text.lower():
                         return 1/(rank+1)
             return False
         else:
-            prediction_ids = [p.id for p in predictions[:self.top_k]]
+            prediction_ids = [p.id for p in predictions[:top_k_eval_documents]]
             label_ids = retriever_labels.multiple_document_ids
             for rank, p in enumerate(prediction_ids):
                 if p in label_ids:
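reciprocal_rank_retrieved now truncates the candidate list to the passed top_k_eval_documents rather than the constructor value: the first hit at zero-based rank r contributes 1/(r+1), and a miss within the cut-off scores 0. A standalone sketch of that computation over plain strings (not the real Haystack label and document objects):

# Simplified reciprocal-rank computation over a ranked list of document texts.
def reciprocal_rank(answers, ranked_texts, top_k):
    for rank, text in enumerate(ranked_texts[:top_k]):   # only the first top_k predictions count
        if any(a.lower() in text.lower() for a in answers):
            return 1 / (rank + 1)
    return 0.0

ranked = ["Berlin is in Germany.", "Paris is the capital of France.", "Rome is in Italy."]
print(reciprocal_rank(["Paris"], ranked, top_k=3))   # 0.5 (first hit at rank index 1)
print(reciprocal_rank(["Paris"], ranked, top_k=1))   # 0.0 (the hit falls outside the cut-off)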
@@ -107,15 +127,15 @@ class EvalDocuments:
         print("-----------------")
         if self.no_answer_count:
             print(
-                f"has_answer recall@{self.top_k}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})")
+                f"has_answer recall@{self.top_k_used}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})")
             print(
-                f"no_answer recall@{self.top_k}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
+                f"no_answer recall@{self.top_k_used}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
             print(
-                f"has_answer mean_reciprocal_rank@{self.top_k}: {self.has_answer_mean_reciprocal_rank:.4f}")
+                f"has_answer mean_reciprocal_rank@{self.top_k_used}: {self.has_answer_mean_reciprocal_rank:.4f}")
             print(
-                f"no_answer mean_reciprocal_rank@{self.top_k}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)")
-        print(f"recall@{self.top_k}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})")
-        print(f"mean_reciprocal_rank@{self.top_k}: {self.mean_reciprocal_rank:.4f}")
+                f"no_answer mean_reciprocal_rank@{self.top_k_used}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)")
+        print(f"recall@{self.top_k_used}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})")
+        print(f"mean_reciprocal_rank@{self.top_k_used}: {self.mean_reciprocal_rank:.4f}")
 
 
 class EvalAnswers:
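After this change every reported metric is keyed by top_k_used, the cut-off actually applied during the run, rather than the constructor's top_k. A tiny self-contained illustration of the resulting label format, reproducing the f-strings above with arbitrary placeholder numbers:

# Illustration of the metric labels printed by EvalDocuments.print() after this change
# (metric values here are arbitrary placeholders, not real results).
top_k_used = 10
recall = 0.8200
mean_reciprocal_rank = 0.6150
correct_retrieval_count, query_count = 82, 100

print(f"recall@{top_k_used}: {recall:.4f} ({correct_retrieval_count} / {query_count})")
print(f"mean_reciprocal_rank@{top_k_used}: {mean_reciprocal_rank:.4f}")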