Make returning predictions in evaluation possible (#524)

* Make returning preds in evaluation possible

* Make returning preds in evaluation possible

* Add automated check if eval dict contains predictions
bogdankostic 2020-10-28 09:55:31 +01:00 committed by GitHub
parent 4fa5d9c3eb
commit 18d315d61a
3 changed files with 48 additions and 8 deletions


@@ -121,6 +121,7 @@ class Finder:
         label_origin: str = "gold_label",
         top_k_retriever: int = 10,
         top_k_reader: int = 10,
+        return_preds: bool = False,
     ):
         """
         Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
@@ -165,6 +166,9 @@
         :type top_k_retriever: int
         :param top_k_reader: How many answers to return per question
         :type top_k_reader: int
+        :param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
+        :type return_preds: bool
         """

         if not self.reader or not self.retriever:
@@ -205,6 +209,7 @@
         previous_return_no_answers = self.reader.return_no_answers
         self.reader.return_no_answers = True
+        predictions = []

         # extract answers
         reader_start_time = time.time()
         for q_idx, question_docs in enumerate(questions_with_docs):
@@ -217,6 +222,8 @@
             single_reader_start = time.time()
             predicted_answers = self.reader.predict(question_string, docs, top_k=top_k_reader)  # type: ignore
             read_times.append(time.time() - single_reader_start)
+            if return_preds:
+                predictions.append(predicted_answers)
             counts = eval_counts_reader(question, predicted_answers, counts)

         counts["number_of_has_answer"] = counts["correct_retrievals"] - counts["number_of_no_answer"]
@@ -240,6 +247,9 @@
         eval_results["avg_reader_time"] = mean(read_times)
         eval_results["total_finder_time"] = finder_total_time

-        return eval_results
+        if return_preds:
+            return {"metrics": eval_results, "predictions": predictions}
+        else:
+            return eval_results

     def eval_batch(
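For orientation, here is a minimal usage sketch of the new flag on Finder.eval. Only the return_preds parameter and the "metrics"/"predictions" keys come from this commit; the import path, the Finder construction, and the index names are assumptions made for the example.

# Usage sketch only. `reader` and `retriever` are assumed to be initialized elsewhere,
# and eval labels/documents are assumed to sit in the default "label" / "eval_document" indices.
from haystack import Finder  # import path assumed for the 0.x API

finder = Finder(reader=reader, retriever=retriever)

# Default behaviour is unchanged: a flat metrics dict comes back.
metrics = finder.eval(label_index="label", doc_index="eval_document")

# With return_preds=True, metrics and per-question reader predictions are nested under two keys.
result = finder.eval(label_index="label", doc_index="eval_document", return_preds=True)
print(result["metrics"]["avg_reader_time"])   # same keys as the flat metrics dict
print(len(result["predictions"]))             # one reader prediction per evaluated question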
@@ -249,7 +259,8 @@
         label_origin: str = "gold_label",
         top_k_retriever: int = 10,
         top_k_reader: int = 10,
-        batch_size: int = 50
+        batch_size: int = 50,
+        return_preds: bool = False,
     ):
         """
         Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
@@ -296,10 +307,13 @@
         :type top_k_reader: int
         :param batch_size: Number of samples per batch computed at once
         :type batch_size: int
+        :param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
+        :type return_preds: bool
         """

         if not self.reader or not self.retriever:
-            raise Exception("Finder needs to have a reader and retriever for the evalutaion.")
+            raise Exception("Finder needs to have a reader and retriever for the evaluation.")

         counts = defaultdict(float)  # type: Dict[str, float]
         finder_start_time = time.time()
@@ -344,6 +358,9 @@
         logger.info(f"{number_of_questions - correct_retrievals} questions could not be answered due to the retriever.")
         logger.info(f"{correct_retrievals - counts['correct_readings_topk']} questions could not be answered due to the reader.")

-        return results
+        if return_preds:
+            return {"metrics": results, "predictions": predictions}
+        else:
+            return results
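eval_batch gets the same treatment; a short sketch, reusing the hypothetical finder from the example above:

# Sketch only; `finder` as in the previous example. The per-item layout of the batched
# predictions is produced by code outside this diff, so it is not spelled out here.
batch_result = finder.eval_batch(
    label_index="label",        # index names assumed, as above
    doc_index="eval_document",
    top_k_retriever=10,
    top_k_reader=10,
    batch_size=50,
    return_preds=True,
)
print(batch_result["metrics"].keys())      # aggregated metrics, same dict as before the change
print(len(batch_result["predictions"]))    # predictions collected during the batched run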
@@ -364,6 +381,9 @@
     @staticmethod
     def print_eval_results(finder_eval_results: Dict):
+        if "predictions" in finder_eval_results.keys():
+            finder_eval_results = finder_eval_results["metrics"]
+
         print("\n___Retriever Metrics in Finder___")
         print(f"Retriever Recall            : {finder_eval_results['retriever_recall']:.3f}")
         print(f"Retriever Mean Avg Precision: {finder_eval_results['retriever_map']:.3f}")


@@ -291,7 +291,7 @@ class FARMReader(BaseReader):
         result = []
         for idx, group in enumerate(grouped_predictions):
             answers, max_no_ans_gap = self._extract_answers_of_predictions(group, top_k_per_question)
-            question = group[0]
+            question = group[0].question
             cur_label = labels[idx]
             result.append({
                 "question": question,


@@ -45,7 +45,8 @@ class BaseRetriever(ABC):
         doc_index: str = "eval_document",
         label_origin: str = "gold_label",
         top_k: int = 10,
-        open_domain: bool = False
+        open_domain: bool = False,
+        return_preds: bool = False,
     ) -> dict:
         """
         Performs evaluation on the Retriever.
@@ -65,6 +66,8 @@
                             contained in the retrieved docs (common approach in open-domain QA).
                             If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
                             are within ids explicitly stated in the labels.
+        :param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
         """

         # Extract all questions for evaluation
@@ -86,11 +89,15 @@
             deduplicated_doc_ids = list(set([str(x) for x in label.multiple_document_ids]))
             question_label_dict[label.question] = deduplicated_doc_ids

+        predictions = []
+
         # Option 1: Open-domain evaluation by checking if the answer string is in the retrieved docs
         logger.info("Performing eval queries...")
         if open_domain:
             for question, gold_answers in tqdm(question_label_dict.items()):
                 retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index)
+                if return_preds:
+                    predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_answer in gold_answers:
@@ -102,6 +109,8 @@
         else:
             for question, gold_ids in tqdm(question_label_dict.items()):
                 retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index)
+                if return_preds:
+                    predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_id in gold_ids:
@@ -117,4 +126,15 @@
         logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in"
                      f" the top-{top_k} candidate passages selected by the retriever."))

-        return {"recall": recall, "map": mean_avg_precision, "retrieve_time": self.retrieve_time, "n_questions": number_of_questions, "top_k": top_k}
+        metrics = {
+            "recall": recall,
+            "map": mean_avg_precision,
+            "retrieve_time": self.retrieve_time,
+            "n_questions": number_of_questions,
+            "top_k": top_k
+        }
+
+        if return_preds:
+            return {"metrics": metrics, "predictions": predictions}
+        else:
+            return metrics
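On the retriever side, each collected prediction is a dict holding the question and its retrieved_docs, as the appends above show. A usage sketch, assuming `retriever` is any BaseRetriever subclass whose eval labels and documents were indexed under the default names:

# Sketch only; `retriever`, the index names, and the Document attributes are assumptions.
retriever_result = retriever.eval(
    label_index="label",
    doc_index="eval_document",
    top_k=10,
    open_domain=True,
    return_preds=True,
)

print(retriever_result["metrics"]["recall"])   # plus "map", "retrieve_time", "n_questions", "top_k"
for pred in retriever_result["predictions"][:3]:
    # each entry mirrors what the eval loop appended: the question and its retrieved documents
    print(pred["question"], [doc.id for doc in pred["retrieved_docs"]])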