Make returning predictions in evaluation possible (#524)

* Make returning preds in evaluation possible

* Make returning preds in evaluation possible

* Add automated check if eval dict contains predictions
Authored by bogdankostic, 2020-10-28 09:55:31 +01:00 (committed by GitHub)
parent 4fa5d9c3eb
commit 18d315d61a
3 changed files with 48 additions and 8 deletions


@@ -121,6 +121,7 @@ class Finder:
label_origin: str = "gold_label",
top_k_retriever: int = 10,
top_k_reader: int = 10,
return_preds: bool = False,
):
"""
Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
@@ -165,6 +166,9 @@
:type top_k_retriever: int
:param top_k_reader: How many answers to return per question
:type top_k_reader: int
:param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
contains the keys "predictions" and "metrics".
:type return_preds: bool
"""
if not self.reader or not self.retriever:
@@ -205,6 +209,7 @@
previous_return_no_answers = self.reader.return_no_answers
self.reader.return_no_answers = True
predictions = []
# extract answers
reader_start_time = time.time()
for q_idx, question_docs in enumerate(questions_with_docs):
@@ -215,8 +220,10 @@
question_string = question.question
docs = question_docs["docs"] # type: ignore
single_reader_start = time.time()
predicted_answers = self.reader.predict(question_string, docs, top_k=top_k_reader) # type: ignore
read_times.append(time.time() - single_reader_start)
if return_preds:
predictions.append(predicted_answers)
counts = eval_counts_reader(question, predicted_answers, counts)
counts["number_of_has_answer"] = counts["correct_retrievals"] - counts["number_of_no_answer"]
@@ -240,7 +247,10 @@
eval_results["avg_reader_time"] = mean(read_times)
eval_results["total_finder_time"] = finder_total_time
return eval_results
if return_preds:
return {"metrics": eval_results, "predictions": predictions}
else:
return eval_results
def eval_batch(
self,
@@ -249,7 +259,8 @@
label_origin: str = "gold_label",
top_k_retriever: int = 10,
top_k_reader: int = 10,
batch_size: int = 50
batch_size: int = 50,
return_preds: bool = False,
):
"""
Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
@@ -296,10 +307,13 @@
:type top_k_reader: int
:param batch_size: Number of samples per batch computed at once
:type batch_size: int
:param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
contains the keys "predictions" and "metrics".
:type return_preds: bool
"""
if not self.reader or not self.retriever:
raise Exception("Finder needs to have a reader and retriever for the evalutaion.")
raise Exception("Finder needs to have a reader and retriever for the evaluation.")
counts = defaultdict(float) # type: Dict[str, float]
finder_start_time = time.time()
@@ -344,7 +358,10 @@
logger.info(f"{number_of_questions - correct_retrievals} questions could not be answered due to the retriever.")
logger.info(f"{correct_retrievals - counts['correct_readings_topk']} questions could not be answered due to the reader.")
return results
if return_preds:
return {"metrics": results, "predictions": predictions}
else:
return results
def _retrieve_docs(self, questions: List[MultiLabel], top_k: int, doc_index: str):
@@ -364,6 +381,9 @@
@staticmethod
def print_eval_results(finder_eval_results: Dict):
if "predictions" in finder_eval_results.keys():
finder_eval_results = finder_eval_results["metrics"]
print("\n___Retriever Metrics in Finder___")
print(f"Retriever Recall : {finder_eval_results['retriever_recall']:.3f}")
print(f"Retriever Mean Avg Precision: {finder_eval_results['retriever_map']:.3f}")


@@ -291,7 +291,7 @@ class FARMReader(BaseReader):
result = []
for idx, group in enumerate(grouped_predictions):
answers, max_no_ans_gap = self._extract_answers_of_predictions(group, top_k_per_question)
question = group[0]
question = group[0].question
cur_label = labels[idx]
result.append({
"question": question,


@@ -45,7 +45,8 @@ class BaseRetriever(ABC):
doc_index: str = "eval_document",
label_origin: str = "gold_label",
top_k: int = 10,
open_domain: bool = False
open_domain: bool = False,
return_preds: bool = False,
) -> dict:
"""
Performs evaluation on the Retriever.
@@ -65,6 +66,8 @@
contained in the retrieved docs (common approach in open-domain QA).
If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
are within ids explicitly stated in the labels.
:param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
contains the keys "predictions" and "metrics".
"""
# Extract all questions for evaluation
@@ -86,11 +89,15 @@
deduplicated_doc_ids = list(set([str(x) for x in label.multiple_document_ids]))
question_label_dict[label.question] = deduplicated_doc_ids
predictions = []
# Option 1: Open-domain evaluation by checking if the answer string is in the retrieved docs
logger.info("Performing eval queries...")
if open_domain:
for question, gold_answers in tqdm(question_label_dict.items()):
retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index)
if return_preds:
predictions.append({"question": question, "retrieved_docs": retrieved_docs})
# check if correct doc in retrieved docs
for doc_idx, doc in enumerate(retrieved_docs):
for gold_answer in gold_answers:
@@ -102,6 +109,8 @@
else:
for question, gold_ids in tqdm(question_label_dict.items()):
retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index)
if return_preds:
predictions.append({"question": question, "retrieved_docs": retrieved_docs})
# check if correct doc in retrieved docs
for doc_idx, doc in enumerate(retrieved_docs):
for gold_id in gold_ids:
@@ -117,4 +126,15 @@
logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in"
f" the top-{top_k} candidate passages selected by the retriever."))
return {"recall": recall, "map": mean_avg_precision, "retrieve_time": self.retrieve_time, "n_questions": number_of_questions, "top_k": top_k}
metrics = {
"recall": recall,
"map": mean_avg_precision,
"retrieve_time": self.retrieve_time,
"n_questions": number_of_questions,
"top_k": top_k
}
if return_preds:
return {"metrics": metrics, "predictions": predictions}
else:
return metrics
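
The retriever-level eval() gains the same flag. Below is a minimal, hypothetical sketch under the assumption that retriever is any BaseRetriever subclass whose document store already contains the evaluation documents and labels under the default index names used in the signature above; none of that setup appears in this diff.

    # Hypothetical usage sketch for BaseRetriever.eval(return_preds=True).
    result = retriever.eval(
        doc_index="eval_document",   # default shown in the signature above
        top_k=10,
        open_domain=True,
        return_preds=True,
    )

    print(result["metrics"]["recall"], result["metrics"]["map"])

    # Each prediction entry pairs a question with the documents the retriever returned for it.
    for pred in result["predictions"][:3]:
        print(pred["question"], len(pred["retrieved_docs"]))

With return_preds=False the method still returns just the metrics dictionary (recall, map, retrieve_time, n_questions, top_k), so the default behaviour is unchanged.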