Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-09-26 16:46:58 +00:00
Add More top_k handling to EvalDocuments (#1133)
* Improve top_k support
* Adjust warning
* Satisfy mypy
* Reinit eval counts if top_k has changed
* Incorporate reviewer feedback
parent c513865566
commit 59e3c55c47
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Dict, Any
+from typing import List, Tuple, Dict, Any, Optional
 import logging
 
 from haystack import MultiLabel, Label
@@ -18,7 +18,7 @@ class EvalDocuments:
     a look at our evaluation tutorial for more info about open vs closed domain eval (
     https://haystack.deepset.ai/docs/latest/tutorial5md).
     """
-    def __init__(self, debug: bool=False, open_domain: bool=True, top_k: int=10, name="EvalDocuments"):
+    def __init__(self, debug: bool=False, open_domain: bool=True, top_k_eval_documents: int=10, name="EvalDocuments"):
         """
         :param open_domain: When True, a document is considered correctly retrieved so long as the answer string can be found within it.
                             When False, correct retrieval is evaluated based on document_id.
@@ -31,8 +31,10 @@ class EvalDocuments:
         self.debug = debug
         self.log: List = []
         self.open_domain = open_domain
-        self.top_k = top_k
+        self.top_k_eval_documents = top_k_eval_documents
         self.name = name
+        self.too_few_docs_warning = False
+        self.top_k_used = 0
 
     def init_counts(self):
         self.correct_retrieval_count = 0
@@ -47,10 +49,26 @@ class EvalDocuments:
         self.reciprocal_rank_sum = 0.0
         self.has_answer_reciprocal_rank_sum = 0.0
 
-    def run(self, documents, labels: dict, **kwargs):
+    def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None, **kwargs):
         """Run this node on one sample and its labels"""
         self.query_count += 1
         retriever_labels = get_label(labels, kwargs["node_id"])
+        if not top_k_eval_documents:
+            top_k_eval_documents = self.top_k_eval_documents
+
+        if not self.top_k_used:
+            self.top_k_used = top_k_eval_documents
+        elif self.top_k_used != top_k_eval_documents:
+            logger.warning(f"EvalDocuments was last run with top_k_eval_documents={self.top_k_used} but is "
+                           f"being run again with top_k_eval_documents={self.top_k_eval_documents}. "
+                           f"The evaluation counter is being reset from this point so that the evaluation "
+                           f"metrics are interpretable.")
+            self.init_counts()
+
+        if len(documents) < top_k_eval_documents and not self.too_few_docs_warning:
+            logger.warning(f"EvalDocuments is being provided less candidate documents than top_k_eval_documents "
+                           f"(currently set to {top_k_eval_documents}).")
+            self.too_few_docs_warning = True
+
         # TODO retriever_labels is currently a Multilabel object but should eventually be a RetrieverLabel object
         # If this sample is impossible to answer and expects a no_answer response
@@ -67,7 +85,7 @@ class EvalDocuments:
         # If there are answer span annotations in the labels
         else:
             self.has_answer_count += 1
-            retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents)
+            retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents, top_k_eval_documents)
             self.reciprocal_rank_sum += retrieved_reciprocal_rank
             correct_retrieval = True if retrieved_reciprocal_rank > 0 else False
             self.has_answer_correct += int(correct_retrieval)
@@ -79,6 +97,8 @@ class EvalDocuments:
         self.recall = self.correct_retrieval_count / self.query_count
         self.mean_reciprocal_rank = self.reciprocal_rank_sum / self.query_count
 
+        self.top_k_used = top_k_eval_documents
+
         if self.debug:
             self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs})
         return {"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs}, "output_1"
@@ -86,15 +106,15 @@ class EvalDocuments:
     def is_correctly_retrieved(self, retriever_labels, predictions):
         return self.reciprocal_rank_retrieved(retriever_labels, predictions) > 0
 
-    def reciprocal_rank_retrieved(self, retriever_labels, predictions):
+    def reciprocal_rank_retrieved(self, retriever_labels, predictions, top_k_eval_documents):
         if self.open_domain:
             for label in retriever_labels.multiple_answers:
-                for rank, p in enumerate(predictions[:self.top_k]):
+                for rank, p in enumerate(predictions[:top_k_eval_documents]):
                     if label.lower() in p.text.lower():
                         return 1/(rank+1)
             return False
         else:
-            prediction_ids = [p.id for p in predictions[:self.top_k]]
+            prediction_ids = [p.id for p in predictions[:top_k_eval_documents]]
             label_ids = retriever_labels.multiple_document_ids
             for rank, p in enumerate(prediction_ids):
                 if p in label_ids:
@@ -107,15 +127,15 @@ class EvalDocuments:
         print("-----------------")
         if self.no_answer_count:
             print(
-                f"has_answer recall@{self.top_k}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})")
+                f"has_answer recall@{self.top_k_used}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})")
             print(
-                f"no_answer recall@{self.top_k}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
+                f"no_answer recall@{self.top_k_used}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
             print(
-                f"has_answer mean_reciprocal_rank@{self.top_k}: {self.has_answer_mean_reciprocal_rank:.4f}")
+                f"has_answer mean_reciprocal_rank@{self.top_k_used}: {self.has_answer_mean_reciprocal_rank:.4f}")
             print(
-                f"no_answer mean_reciprocal_rank@{self.top_k}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)")
+                f"no_answer mean_reciprocal_rank@{self.top_k_used}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)")
-        print(f"recall@{self.top_k}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})")
+        print(f"recall@{self.top_k_used}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})")
-        print(f"mean_reciprocal_rank@{self.top_k}: {self.mean_reciprocal_rank:.4f}")
+        print(f"mean_reciprocal_rank@{self.top_k_used}: {self.mean_reciprocal_rank:.4f}")
 
 
 class EvalAnswers:
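
To make the behaviour added by this commit concrete, here is a minimal, standalone sketch in plain Python. It is not the Haystack API: the class name TopKEvalSketch and the simplified document/answer shapes are invented for illustration. It only mimics two things the diff introduces: reciprocal rank is computed over just the first top_k_eval_documents candidates, and the evaluation counters are reset when run() is later called with a different top_k.

from typing import List, Optional


class TopKEvalSketch:
    def __init__(self, top_k_eval_documents: int = 10):
        self.top_k_eval_documents = top_k_eval_documents
        self.top_k_used = 0  # remembers which top_k the counters were accumulated with
        self.init_counts()

    def init_counts(self):
        self.query_count = 0
        self.correct_retrieval_count = 0
        self.reciprocal_rank_sum = 0.0

    def run(self, documents: List[str], answers: List[str], top_k_eval_documents: Optional[int] = None):
        if not top_k_eval_documents:
            top_k_eval_documents = self.top_k_eval_documents

        # If top_k changes between calls, earlier counts are no longer comparable,
        # so reset them (mirrors the warning + init_counts() added in this commit).
        if not self.top_k_used:
            self.top_k_used = top_k_eval_documents
        elif self.top_k_used != top_k_eval_documents:
            print(f"top_k changed from {self.top_k_used} to {top_k_eval_documents}; resetting counters")
            self.init_counts()

        self.query_count += 1
        rr = self.reciprocal_rank(documents, answers, top_k_eval_documents)
        self.reciprocal_rank_sum += rr
        self.correct_retrieval_count += int(rr > 0)
        self.top_k_used = top_k_eval_documents

    @staticmethod
    def reciprocal_rank(documents: List[str], answers: List[str], top_k: int) -> float:
        # Open-domain style check: a document is a hit if any answer string occurs in it.
        for rank, doc in enumerate(documents[:top_k]):
            if any(answer.lower() in doc.lower() for answer in answers):
                return 1 / (rank + 1)
        return 0.0


if __name__ == "__main__":
    evaluator = TopKEvalSketch(top_k_eval_documents=2)
    docs = ["Paris is the capital of France.", "Berlin is the capital of Germany."]
    evaluator.run(docs, answers=["Berlin"])                           # hit at rank 2 -> RR = 0.5
    evaluator.run(docs, answers=["Berlin"], top_k_eval_documents=1)   # top_k changed -> counters reset, then a miss
    print(evaluator.query_count, evaluator.reciprocal_rank_sum)       # 1 0.0

The second run() call illustrates why the reset exists: metrics labelled recall@2 and recall@1 would otherwise be mixed in the same counters.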