Mirror of https://github.com/deepset-ai/haystack.git
Fix retriever evaluation metrics (#547)
* Add mean reciprocal rank and fix mean average precision
* Add mrr metric to docstring
* Fix mypy error
This commit is contained in:
parent 53be92c155
commit ffaa0249f7
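
For context, the two retriever metrics touched by this commit behave as follows on a toy ranking. This is an illustrative sketch only, not code from the commit; the helper functions and document ids are made up.

    # Illustrative helpers (hypothetical names, toy data), matching the definitions used in this commit.
    def reciprocal_rank(retrieved_ids, relevant_ids):
        # 1 / rank of the first relevant document; 0.0 if no relevant document was retrieved
        for rank, doc_id in enumerate(retrieved_ids, start=1):
            if doc_id in relevant_ids:
                return 1.0 / rank
        return 0.0

    def average_precision(retrieved_ids, relevant_ids):
        # mean of precision@k over all relevant documents (the corrected MAP behaviour)
        relevant_found = 0
        summed_precision = 0.0
        for rank, doc_id in enumerate(retrieved_ids, start=1):
            if doc_id in relevant_ids:
                relevant_found += 1
                summed_precision += relevant_found / rank
        return summed_precision / len(relevant_ids)

    retrieved = ["d7", "d2", "d4", "d9"]
    relevant = {"d2", "d4"}
    print(reciprocal_rank(retrieved, relevant))    # 0.5: only the first hit (rank 2) counts
    print(average_precision(retrieved, relevant))  # 0.5833...: (1/2 + 2/3) / 2, every hit counts

Before this fix, the evaluation code added 1 / (doc_idx + 1) only for the first relevant document and then stopped, which is exactly the reciprocal-rank term, so the reported "map" was effectively MRR rather than mean average precision.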
@@ -33,22 +33,33 @@ def calculate_reader_metrics(metric_counts: Dict[str, float], correct_retrievals
     return metrics


-def calculate_average_precision(questions_with_docs: List[dict]):
+def calculate_average_precision_and_reciprocal_rank(questions_with_docs: List[dict]):
     questions_with_correct_doc = []
     summed_avg_precision_retriever = 0.0
+    summed_reciprocal_rank_retriever = 0.0

     for question in questions_with_docs:
+        number_relevant_docs = len(set(question["question"].multiple_document_ids))
+        found_relevant_doc = False
+        relevant_docs_found = 0
         for doc_idx, doc in enumerate(question["docs"]):
             # check if correct doc among retrieved docs
             if doc.id in question["question"].multiple_document_ids:
-                summed_avg_precision_retriever += 1 / (doc_idx + 1)
-                questions_with_correct_doc.append({
-                    "question": question["question"],
-                    "docs": question["docs"]
-                })
-                break
+                if not found_relevant_doc:
+                    summed_reciprocal_rank_retriever += 1 / (doc_idx + 1)
+                relevant_docs_found += 1
+                found_relevant_doc = True
+                summed_avg_precision_retriever += (1 / number_relevant_docs) * (relevant_docs_found / (doc_idx + 1))
+                if relevant_docs_found == number_relevant_docs:
+                    break

-    return questions_with_correct_doc, summed_avg_precision_retriever
+        if found_relevant_doc:
+            questions_with_correct_doc.append({
+                "question": question["question"],
+                "docs": question["docs"]
+            })
+
+    return questions_with_correct_doc, summed_avg_precision_retriever, summed_reciprocal_rank_retriever


 def eval_counts_reader(question: MultiLabel, predicted_answers: Dict[str, Any], metric_counts: Dict[str, float]):
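
As a quick sanity check of the renamed helper, the snippet below calls calculate_average_precision_and_reciprocal_rank from haystack.eval with stand-in objects in place of haystack's MultiLabel and Document (only the attributes the function reads are mocked). It is a sketch rather than part of the test suite, reusing the toy ranking from above.

    # Sketch only: mocked inputs, expected values worked out by hand for this toy ranking.
    from types import SimpleNamespace
    from haystack.eval import calculate_average_precision_and_reciprocal_rank

    # a "question" only needs .multiple_document_ids here; a doc only needs .id
    question = SimpleNamespace(multiple_document_ids=["d2", "d4"])
    docs = [SimpleNamespace(id=doc_id) for doc_id in ["d7", "d2", "d4", "d9"]]

    with_correct_doc, summed_ap, summed_rr = calculate_average_precision_and_reciprocal_rank(
        [{"question": question, "docs": docs}]
    )

    assert len(with_correct_doc) == 1               # the relevant documents were retrieved
    assert round(summed_rr, 4) == 0.5               # first relevant document sits at rank 2
    assert round(summed_ap, 4) == round(7 / 12, 4)  # (1/2) * (1/2) + (1/2) * (2/3)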
@@ -8,8 +8,8 @@ from collections import defaultdict
 from haystack.reader.base import BaseReader
 from haystack.retriever.base import BaseRetriever
 from haystack import MultiLabel
-from haystack.eval import calculate_average_precision, eval_counts_reader_batch, calculate_reader_metrics, \
-    eval_counts_reader
+from haystack.eval import calculate_average_precision_and_reciprocal_rank, eval_counts_reader_batch, \
+    calculate_reader_metrics, eval_counts_reader

 logger = logging.getLogger(__name__)

@@ -131,7 +131,9 @@ class Finder:
         Returns a dict containing the following metrics:
         - ``"retriever_recall"``: Proportion of questions for which correct document is among retrieved documents
         - ``"retriever_map"``: Mean of average precision for each question. Rewards retrievers that give relevant
-          documents a higher rank.
+          documents a higher rank. Considers all retrieved relevant documents.
+        - ``"retriever_mrr"``: Mean of reciprocal rank for each question. Rewards retrievers that give relevant
+          documents a higher rank. Only considers the highest ranked relevant document.
         - ``"reader_top1_accuracy"``: Proportion of highest ranked predicted answers that overlap with corresponding correct answer
         - ``"reader_top1_accuracy_has_answer"``: Proportion of highest ranked predicted answers that overlap
           with corresponding correct answer for answerable questions
@@ -193,17 +195,28 @@ class Finder:
             single_retrieve_start = time.time()
             retrieved_docs = self.retriever.retrieve(question_string, top_k=top_k_retriever, index=doc_index)
             retrieve_times.append(time.time() - single_retrieve_start)
+            number_relevant_docs = len(set(question.multiple_document_ids))

             # check if correct doc among retrieved docs
+            found_relevant_doc = False
+            relevant_docs_found = 0
             for doc_idx, doc in enumerate(retrieved_docs):
                 if doc.id in question.multiple_document_ids:
-                    counts["correct_retrievals"] += 1
-                    counts["summed_avg_precision_retriever"] += 1 / (doc_idx + 1)
-                    questions_with_docs.append({
-                        "question": question,
-                        "docs": retrieved_docs
-                    })
-                    break
+                    relevant_docs_found += 1
+                    if not found_relevant_doc:
+                        counts["correct_retrievals"] += 1
+                        counts["summed_reciprocal_rank_retriever"] += 1 / (doc_idx + 1)
+                    counts["summed_avg_precision_retriever"] += (1 / number_relevant_docs) \
+                                                                * (relevant_docs_found / (doc_idx + 1))
+                    found_relevant_doc = True
+                    if relevant_docs_found == number_relevant_docs:
+                        break

+            if found_relevant_doc:
+                questions_with_docs.append({
+                    "question": question,
+                    "docs": retrieved_docs
+                })
+
         retriever_total_time = time.time() - retriever_start_time
         counts["number_of_questions"] = q_idx + 1
@@ -270,7 +283,9 @@ class Finder:
         Returns a dict containing the following metrics:
         - ``"retriever_recall"``: Proportion of questions for which correct document is among retrieved documents
         - ``"retriever_map"``: Mean of average precision for each question. Rewards retrievers that give relevant
-          documents a higher rank.
+          documents a higher rank. Considers all retrieved relevant documents.
+        - ``"retriever_mrr"``: Mean of reciprocal rank for each question. Rewards retrievers that give relevant
+          documents a higher rank. Only considers the highest ranked relevant document.
         - ``"reader_top1_accuracy"``: Proportion of highest ranked predicted answers that overlap with corresponding correct answer
         - ``"reader_top1_accuracy_has_answer"``: Proportion of highest ranked predicted answers that overlap
           with corresponding correct answer for answerable questions
@@ -330,7 +345,10 @@ class Finder:
         questions_with_docs = self._retrieve_docs(questions, top_k=top_k_retriever, doc_index=doc_index)
         retriever_total_time = time.time() - retriever_start_time

-        questions_with_correct_doc, summed_avg_precision_retriever = calculate_average_precision(questions_with_docs)
+        questions_with_correct_doc, \
+        summed_avg_precision_retriever, \
+        summed_reciprocal_rank_retriever = calculate_average_precision_and_reciprocal_rank(questions_with_docs)
+
         correct_retrievals = len(questions_with_correct_doc)

         # extract answers
@@ -349,6 +367,7 @@ class Finder:
         results = calculate_reader_metrics(counts, correct_retrievals)
         results["retriever_recall"] = correct_retrievals / number_of_questions
         results["retriever_map"] = summed_avg_precision_retriever / number_of_questions
+        results["retriever_mrr"] = summed_reciprocal_rank_retriever / number_of_questions
         results["total_retrieve_time"] = retriever_total_time
         results["avg_retrieve_time"] = retriever_total_time / number_of_questions
         results["total_reader_time"] = reader_total_time
@@ -389,6 +408,7 @@ class Finder:
         print("\n___Retriever Metrics in Finder___")
         print(f"Retriever Recall            : {finder_eval_results['retriever_recall']:.3f}")
         print(f"Retriever Mean Avg Precision: {finder_eval_results['retriever_map']:.3f}")
+        print(f"Retriever Mean Reciprocal Rank: {finder_eval_results['retriever_mrr']:.3f}")

         # Reader is only evaluated with those questions, where the correct document is among the retrieved ones
         print("\n___Reader Metrics in Finder___")
@@ -430,6 +450,7 @@ class Finder:

         eval_results["retriever_recall"] = eval_counts["correct_retrievals"] / number_of_questions
         eval_results["retriever_map"] = eval_counts["summed_avg_precision_retriever"] / number_of_questions
+        eval_results["retriever_mrr"] = eval_counts["summed_reciprocal_rank_retriever"] / number_of_questions

         eval_results["reader_top1_accuracy"] = eval_counts["correct_readings_top1"] / correct_retrievals
         eval_results["reader_top1_accuracy_has_answer"] = eval_counts["correct_readings_top1_has_answer"] / number_of_has_answer
@@ -56,8 +56,10 @@ class BaseRetriever(ABC):
         | Returns a dict containing the following metrics:

         - "recall": Proportion of questions for which correct document is among retrieved documents
-        - "mean avg precision": Mean of average precision for each question. Rewards retrievers that give relevant
-          documents a higher rank.
+        - "mrr": Mean of reciprocal rank. Rewards retrievers that give relevant documents a higher rank.
+          Only considers the highest ranked relevant document.
+        - "map": Mean of average precision for each question. Rewards retrievers that give relevant
+          documents a higher rank. Considers all retrieved relevant documents. (only with ``open_domain=False``)

         :param label_index: Index/Table in DocumentStore where labeled questions are stored
         :param doc_index: Index/Table in DocumentStore where documents that are used for evaluation are stored
@@ -78,7 +80,8 @@ class BaseRetriever(ABC):
         labels = self.document_store.get_all_labels_aggregated(index=label_index, filters=filters)

         correct_retrievals = 0
-        summed_avg_precision = 0
+        summed_avg_precision = 0.0
+        summed_reciprocal_rank = 0.0

         # Collect questions and corresponding answers/document_ids in a dict
         question_label_dict = {}
@@ -99,12 +102,18 @@ class BaseRetriever(ABC):
                 if return_preds:
                     predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
+                found_relevant_doc = False
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_answer in gold_answers:
                         if gold_answer in doc.text:
-                            correct_retrievals += 1
-                            summed_avg_precision += 1 / (doc_idx + 1)  # type: ignore
+                            if not found_relevant_doc:
+                                correct_retrievals += 1
+                                summed_reciprocal_rank += 1 / (doc_idx + 1)
+                            found_relevant_doc = True
                             break
+                    # For the metrics in the open-domain case we are only considering the highest ranked relevant doc
+                    if found_relevant_doc:
+                        break
         # Option 2: Strict evaluation by document ids that are listed in the labels
         else:
             for question, gold_ids in tqdm(question_label_dict.items()):
@@ -112,28 +121,38 @@ class BaseRetriever(ABC):
                 if return_preds:
                     predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
+                relevant_docs_found = 0
+                found_relevant_doc = False
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_id in gold_ids:
                         if str(doc.id) == gold_id:
-                            correct_retrievals += 1
-                            summed_avg_precision += 1 / (doc_idx + 1)  # type: ignore
+                            if not found_relevant_doc:
+                                correct_retrievals += 1
+                                summed_reciprocal_rank += 1 / (doc_idx + 1)
+                            found_relevant_doc = True
+                            relevant_docs_found += 1
+                            summed_avg_precision += (1 / len(gold_ids)) * (relevant_docs_found / (doc_idx + 1))
                             break
         # Metrics
         number_of_questions = len(question_label_dict)
         recall = correct_retrievals / number_of_questions
-        mean_avg_precision = summed_avg_precision / number_of_questions
+        mean_reciprocal_rank = summed_reciprocal_rank / number_of_questions

         logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in"
                      f" the top-{top_k} candidate passages selected by the retriever."))

         metrics = {
             "recall": recall,
-            "map": mean_avg_precision,
+            "mrr": mean_reciprocal_rank,
             "retrieve_time": self.retrieve_time,
             "n_questions": number_of_questions,
             "top_k": top_k
         }

+        if not open_domain:
+            mean_avg_precision = summed_avg_precision / number_of_questions
+            metrics["map"] = mean_avg_precision
+
         if return_preds:
             return {"metrics": metrics, "predictions": predictions}
         else:
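
With this change, BaseRetriever.eval() always reports "recall" and "mrr", while "map" is only added in the closed-domain case. A usage sketch follows; the retriever instance and index names are placeholders, and it assumes labels and documents have already been written to the document store.

    # Hypothetical call: `retriever` stands for any configured haystack retriever.
    metrics = retriever.eval(label_index="label", doc_index="eval_document",
                             top_k=10, open_domain=False)

    print(f"recall: {metrics['recall']:.3f}")
    print(f"mrr:    {metrics['mrr']:.3f}")
    if "map" in metrics:  # present only when open_domain=False
        print(f"map:    {metrics['map']:.3f}")
    print(f"evaluated {metrics['n_questions']} questions at top_k={metrics['top_k']}")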
@@ -75,7 +75,9 @@ def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain,
     # eval retriever
     results = retriever.eval(top_k=1, label_index="test_feedback", doc_index="test_eval_document", open_domain=open_domain)
     assert results["recall"] == 1.0
-    assert results["map"] == 1.0
+    assert results["mrr"] == 1.0
+    if not open_domain:
+        assert results["map"] == 1.0

     # clean up
     document_store.delete_all_documents(index="test_eval_document")