Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-10-31 17:59:27 +00:00)
Fix retriever evaluation metrics (#547)

* Add mean reciprocal rank and fix mean average precision
* Add mrr metric to docstring
* Fix mypy error
This commit is contained in:
parent 53be92c155
commit ffaa0249f7
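Not part of the commit: a minimal sketch of the two per-question quantities the corrected code accumulates. Average precision credits every relevant document that gets retrieved (the old code added 1 / rank of the first hit and then stopped, which is really a reciprocal rank), while reciprocal rank only looks at the highest-ranked relevant document. The helper names and toy document ids below are illustrative, not Haystack APIs.

from typing import List, Set

def average_precision(retrieved_ids: List[str], relevant_ids: Set[str]) -> float:
    """AP for one question: for each relevant doc found at rank k (1-based),
    add precision@k, then divide by the total number of relevant docs."""
    relevant_found = 0
    summed_precision = 0.0
    for rank, doc_id in enumerate(retrieved_ids, start=1):
        if doc_id in relevant_ids:
            relevant_found += 1
            summed_precision += relevant_found / rank
    return summed_precision / len(relevant_ids) if relevant_ids else 0.0

def reciprocal_rank(retrieved_ids: List[str], relevant_ids: Set[str]) -> float:
    """RR for one question: 1 / rank of the first relevant doc, 0 if none found."""
    for rank, doc_id in enumerate(retrieved_ids, start=1):
        if doc_id in relevant_ids:
            return 1 / rank
    return 0.0

# Example: relevant docs "a" and "b" retrieved at ranks 1 and 3
print(average_precision(["a", "x", "b"], {"a", "b"}))  # (1/1 + 2/3) / 2 = 0.833...
print(reciprocal_rank(["a", "x", "b"], {"a", "b"}))    # 1/1 = 1.0

MAP and MRR, as reported by the changed code below, are the means of these per-question values over all evaluated questions.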
					
@@ -33,22 +33,33 @@ def calculate_reader_metrics(metric_counts: Dict[str, float], correct_retrievals
     return metrics
 
 
-def calculate_average_precision(questions_with_docs: List[dict]):
+def calculate_average_precision_and_reciprocal_rank(questions_with_docs: List[dict]):
     questions_with_correct_doc = []
     summed_avg_precision_retriever = 0.0
+    summed_reciprocal_rank_retriever = 0.0
 
     for question in questions_with_docs:
+        number_relevant_docs = len(set(question["question"].multiple_document_ids))
+        found_relevant_doc = False
+        relevant_docs_found = 0
         for doc_idx, doc in enumerate(question["docs"]):
             # check if correct doc among retrieved docs
             if doc.id in question["question"].multiple_document_ids:
-                summed_avg_precision_retriever += 1 / (doc_idx + 1)
-                questions_with_correct_doc.append({
-                    "question": question["question"],
-                    "docs": question["docs"]
-                })
-                break
+                if not found_relevant_doc:
+                    summed_reciprocal_rank_retriever += 1 / (doc_idx + 1)
+                relevant_docs_found += 1
+                found_relevant_doc = True
+                summed_avg_precision_retriever += (1 / number_relevant_docs) * (relevant_docs_found / (doc_idx + 1))
+                if relevant_docs_found == number_relevant_docs:
+                    break
 
-    return questions_with_correct_doc, summed_avg_precision_retriever
+        if found_relevant_doc:
+            questions_with_correct_doc.append({
+                "question": question["question"],
+                "docs": question["docs"]
+            })
+
+    return questions_with_correct_doc, summed_avg_precision_retriever, summed_reciprocal_rank_retriever
 
 
 def eval_counts_reader(question: MultiLabel, predicted_answers: Dict[str, Any], metric_counts: Dict[str, float]):
@@ -8,8 +8,8 @@ from collections import defaultdict
 from haystack.reader.base import BaseReader
 from haystack.retriever.base import BaseRetriever
 from haystack import MultiLabel
-from haystack.eval import calculate_average_precision, eval_counts_reader_batch, calculate_reader_metrics, \
-    eval_counts_reader
+from haystack.eval import calculate_average_precision_and_reciprocal_rank, eval_counts_reader_batch, \
+    calculate_reader_metrics, eval_counts_reader
 
 logger = logging.getLogger(__name__)
 
@@ -131,7 +131,9 @@ class Finder:
         Returns a dict containing the following metrics:
             - ``"retriever_recall"``: Proportion of questions for which correct document is among retrieved documents
             - ``"retriever_map"``: Mean of average precision for each question. Rewards retrievers that give relevant
-              documents a higher rank.
+              documents a higher rank. Considers all retrieved relevant documents.
+            - ``"retriever_mrr"``: Mean of reciprocal rank for each question. Rewards retrievers that give relevant
+              documents a higher rank. Only considers the highest ranked relevant document.
             - ``"reader_top1_accuracy"``: Proportion of highest ranked predicted answers that overlap with corresponding correct answer
             - ``"reader_top1_accuracy_has_answer"``: Proportion of highest ranked predicted answers that overlap
               with corresponding correct answer for answerable questions
@@ -193,17 +195,28 @@ class Finder:
             single_retrieve_start = time.time()
             retrieved_docs = self.retriever.retrieve(question_string, top_k=top_k_retriever, index=doc_index)
             retrieve_times.append(time.time() - single_retrieve_start)
+            number_relevant_docs = len(set(question.multiple_document_ids))
 
             # check if correct doc among retrieved docs
+            found_relevant_doc = False
+            relevant_docs_found = 0
             for doc_idx, doc in enumerate(retrieved_docs):
                 if doc.id in question.multiple_document_ids:
-                    counts["correct_retrievals"] += 1
-                    counts["summed_avg_precision_retriever"] += 1 / (doc_idx + 1)
-                    questions_with_docs.append({
-                        "question": question,
-                        "docs": retrieved_docs
-                    })
-                    break
+                    relevant_docs_found += 1
+                    if not found_relevant_doc:
+                        counts["correct_retrievals"] += 1
+                        counts["summed_reciprocal_rank_retriever"] += 1 / (doc_idx + 1)
+                    counts["summed_avg_precision_retriever"] += (1 / number_relevant_docs) \
+                                                                * (relevant_docs_found / (doc_idx + 1))
+                    found_relevant_doc = True
+                    if relevant_docs_found == number_relevant_docs:
+                        break
+
+            if found_relevant_doc:
+                questions_with_docs.append({
+                    "question": question,
+                    "docs": retrieved_docs
+                })
 
         retriever_total_time = time.time() - retriever_start_time
         counts["number_of_questions"] = q_idx + 1
@@ -270,7 +283,9 @@ class Finder:
         Returns a dict containing the following metrics:
             - ``"retriever_recall"``: Proportion of questions for which correct document is among retrieved documents
             - ``"retriever_map"``: Mean of average precision for each question. Rewards retrievers that give relevant
-              documents a higher rank.
+              documents a higher rank. Considers all retrieved relevant documents.
+            - ``"retriever_mrr"``: Mean of reciprocal rank for each question. Rewards retrievers that give relevant
+              documents a higher rank. Only considers the highest ranked relevant document.
             - ``"reader_top1_accuracy"``: Proportion of highest ranked predicted answers that overlap with corresponding correct answer
             - ``"reader_top1_accuracy_has_answer"``: Proportion of highest ranked predicted answers that overlap
               with corresponding correct answer for answerable questions
@@ -330,7 +345,10 @@ class Finder:
         questions_with_docs = self._retrieve_docs(questions, top_k=top_k_retriever, doc_index=doc_index)
         retriever_total_time = time.time() - retriever_start_time
 
-        questions_with_correct_doc, summed_avg_precision_retriever = calculate_average_precision(questions_with_docs)
+        questions_with_correct_doc, \
+        summed_avg_precision_retriever, \
+        summed_reciprocal_rank_retriever = calculate_average_precision_and_reciprocal_rank(questions_with_docs)
+
         correct_retrievals = len(questions_with_correct_doc)
 
         # extract answers
@@ -349,6 +367,7 @@ class Finder:
         results = calculate_reader_metrics(counts, correct_retrievals)
         results["retriever_recall"] = correct_retrievals / number_of_questions
         results["retriever_map"] = summed_avg_precision_retriever / number_of_questions
+        results["retriever_mrr"] = summed_reciprocal_rank_retriever / number_of_questions
         results["total_retrieve_time"] = retriever_total_time
         results["avg_retrieve_time"] = retriever_total_time / number_of_questions
         results["total_reader_time"] = reader_total_time
@@ -389,6 +408,7 @@ class Finder:
         print("\n___Retriever Metrics in Finder___")
         print(f"Retriever Recall            : {finder_eval_results['retriever_recall']:.3f}")
         print(f"Retriever Mean Avg Precision: {finder_eval_results['retriever_map']:.3f}")
+        print(f"Retriever Mean Reciprocal Rank: {finder_eval_results['retriever_mrr']:.3f}")
 
         # Reader is only evaluated with those questions, where the correct document is among the retrieved ones
         print("\n___Reader Metrics in Finder___")
@@ -430,6 +450,7 @@ class Finder:
 
         eval_results["retriever_recall"] = eval_counts["correct_retrievals"] / number_of_questions
         eval_results["retriever_map"] = eval_counts["summed_avg_precision_retriever"] / number_of_questions
+        eval_results["retriever_mrr"] = eval_counts["summed_reciprocal_rank_retriever"] / number_of_questions
 
         eval_results["reader_top1_accuracy"] = eval_counts["correct_readings_top1"] / correct_retrievals
         eval_results["reader_top1_accuracy_has_answer"] = eval_counts["correct_readings_top1_has_answer"] / number_of_has_answer
@@ -56,8 +56,10 @@ class BaseRetriever(ABC):
         |  Returns a dict containing the following metrics:
 
             - "recall": Proportion of questions for which correct document is among retrieved documents
-            - "mean avg precision": Mean of average precision for each question. Rewards retrievers that give relevant
-              documents a higher rank.
+            - "mrr": Mean of reciprocal rank. Rewards retrievers that give relevant documents a higher rank.
+              Only considers the highest ranked relevant document.
+            - "map": Mean of average precision for each question. Rewards retrievers that give relevant
+              documents a higher rank. Considers all retrieved relevant documents. (only with ``open_domain=False``)
 
         :param label_index: Index/Table in DocumentStore where labeled questions are stored
         :param doc_index: Index/Table in DocumentStore where documents that are used for evaluation are stored
@@ -78,7 +80,8 @@ class BaseRetriever(ABC):
         labels = self.document_store.get_all_labels_aggregated(index=label_index, filters=filters)
 
         correct_retrievals = 0
-        summed_avg_precision = 0
+        summed_avg_precision = 0.0
+        summed_reciprocal_rank = 0.0
 
         # Collect questions and corresponding answers/document_ids in a dict
         question_label_dict = {}
@@ -99,12 +102,18 @@ class BaseRetriever(ABC):
                 if return_preds:
                     predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
+                found_relevant_doc = False
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_answer in gold_answers:
                         if gold_answer in doc.text:
-                            correct_retrievals += 1
-                            summed_avg_precision += 1 / (doc_idx + 1)  # type: ignore
+                            if not found_relevant_doc:
+                                correct_retrievals += 1
+                                summed_reciprocal_rank += 1 / (doc_idx + 1)
+                            found_relevant_doc = True
                             break
+                    # For the metrics in the open-domain case we are only considering the highest ranked relevant doc
+                    if found_relevant_doc:
+                        break
         # Option 2: Strict evaluation by document ids that are listed in the labels
         else:
             for question, gold_ids in tqdm(question_label_dict.items()):
@@ -112,28 +121,38 @@ class BaseRetriever(ABC):
                 if return_preds:
                     predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
+                relevant_docs_found = 0
+                found_relevant_doc = False
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_id in gold_ids:
                         if str(doc.id) == gold_id:
-                            correct_retrievals += 1
-                            summed_avg_precision += 1 / (doc_idx + 1)  # type: ignore
+                            if not found_relevant_doc:
+                                correct_retrievals += 1
+                                summed_reciprocal_rank += 1 / (doc_idx + 1)
+                            found_relevant_doc = True
+                            relevant_docs_found += 1
+                            summed_avg_precision += (1 / len(gold_ids)) * (relevant_docs_found / (doc_idx + 1))
                             break
         # Metrics
         number_of_questions = len(question_label_dict)
         recall = correct_retrievals / number_of_questions
-        mean_avg_precision = summed_avg_precision / number_of_questions
+        mean_reciprocal_rank = summed_reciprocal_rank / number_of_questions
 
         logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in"
                      f" the top-{top_k} candidate passages selected by the retriever."))
 
         metrics =  {
             "recall": recall,
-            "map": mean_avg_precision,
+            "mrr": mean_reciprocal_rank,
             "retrieve_time": self.retrieve_time,
             "n_questions": number_of_questions,
             "top_k": top_k
         }
 
+        if not open_domain:
+            mean_avg_precision = summed_avg_precision / number_of_questions
+            metrics["map"] = mean_avg_precision
+
         if return_preds:
             return {"metrics": metrics, "predictions": predictions}
         else:
@@ -75,7 +75,9 @@ def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain,
     # eval retriever
     results = retriever.eval(top_k=1, label_index="test_feedback", doc_index="test_eval_document", open_domain=open_domain)
     assert results["recall"] == 1.0
-    assert results["map"] == 1.0
+    assert results["mrr"] == 1.0
+    if not open_domain:
+        assert results["map"] == 1.0
 
     # clean up
     document_store.delete_all_documents(index="test_eval_document")
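For orientation, a hedged usage sketch of how the updated BaseRetriever.eval() output might be consumed after this change; the index names and top_k value are placeholders, not taken from the commit. As documented above, "map" is only reported for the strict (open_domain=False) evaluation, while "recall" and "mrr" are always present.

# Illustrative only: assumes an initialized `retriever` and a populated document store.
results = retriever.eval(
    top_k=10,                      # placeholder value
    label_index="eval_labels",     # assumed index holding labeled questions
    doc_index="eval_docs",         # assumed index holding evaluation documents
    open_domain=False,             # strict evaluation by gold document ids
)

print(f"recall: {results['recall']:.3f}")
print(f"mrr:    {results['mrr']:.3f}")
if "map" in results:               # only present when open_domain=False
    print(f"map:    {results['map']:.3f}")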
bogdankostic