From 18d315d61a756e9fed32a9c78ffaa5689bda001e Mon Sep 17 00:00:00 2001
From: bogdankostic
Date: Wed, 28 Oct 2020 09:55:31 +0100
Subject: [PATCH] Make returning predictions in evaluation possible (#524)

* Make returning preds in evaluation possible

* Make returning preds in evaluation possible

* Add automated check if eval dict contains predictions
---
 haystack/finder.py         | 30 +++++++++++++++++++++++++-----
 haystack/reader/farm.py    |  2 +-
 haystack/retriever/base.py | 24 ++++++++++++++++++++++--
 3 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/haystack/finder.py b/haystack/finder.py
index 4ead6911c..94a389dd9 100644
--- a/haystack/finder.py
+++ b/haystack/finder.py
@@ -121,6 +121,7 @@ class Finder:
         label_origin: str = "gold_label",
         top_k_retriever: int = 10,
         top_k_reader: int = 10,
+        return_preds: bool = False,
     ):
         """
         Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
@@ -165,6 +166,9 @@ class Finder:
         :type top_k_retriever: int
         :param top_k_reader: How many answers to return per question
         :type top_k_reader: int
+        :param return_preds: Whether to include predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
+        :type return_preds: bool
         """

         if not self.reader or not self.retriever:
@@ -205,6 +209,7 @@ class Finder:
         previous_return_no_answers = self.reader.return_no_answers
         self.reader.return_no_answers = True

+        predictions = []
         # extract answers
         reader_start_time = time.time()
         for q_idx, question_docs in enumerate(questions_with_docs):
@@ -215,8 +220,10 @@ class Finder:
             question_string = question.question
             docs = question_docs["docs"]  # type: ignore
             single_reader_start = time.time()
-            predicted_answers = self.reader.predict(question_string, docs, top_k=top_k_reader)  # type: ignore
+            predicted_answers = self.reader.predict(question_string, docs, top_k=top_k_reader)  # type: ignore
             read_times.append(time.time() - single_reader_start)
+            if return_preds:
+                predictions.append(predicted_answers)
             counts = eval_counts_reader(question, predicted_answers, counts)

         counts["number_of_has_answer"] = counts["correct_retrievals"] - counts["number_of_no_answer"]
@@ -240,7 +247,10 @@ class Finder:
         eval_results["avg_reader_time"] = mean(read_times)
         eval_results["total_finder_time"] = finder_total_time

-        return eval_results
+        if return_preds:
+            return {"metrics": eval_results, "predictions": predictions}
+        else:
+            return eval_results

     def eval_batch(
         self,
@@ -249,7 +259,8 @@ class Finder:
         label_origin: str = "gold_label",
         top_k_retriever: int = 10,
         top_k_reader: int = 10,
-        batch_size: int = 50
+        batch_size: int = 50,
+        return_preds: bool = False,
     ):
         """
         Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
@@ -296,10 +307,13 @@ class Finder:
         :type top_k_reader: int
         :param batch_size: Number of samples per batch computed at once
         :type batch_size: int
+        :param return_preds: Whether to include predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
+        :type return_preds: bool
         """

         if not self.reader or not self.retriever:
-            raise Exception("Finder needs to have a reader and retriever for the evalutaion.")
+            raise Exception("Finder needs to have a reader and retriever for the evaluation.")

         counts = defaultdict(float)  # type: Dict[str, float]
         finder_start_time = time.time()
@@ -344,7 +358,10 @@ class Finder:
         logger.info(f"{number_of_questions - correct_retrievals} questions could not be answered due to the retriever.")
         logger.info(f"{correct_retrievals - counts['correct_readings_topk']} questions could not be answered due to the reader.")

-        return results
+        if return_preds:
+            return {"metrics": results, "predictions": predictions}
+        else:
+            return results

     def _retrieve_docs(self, questions: List[MultiLabel], top_k: int, doc_index: str):
@@ -364,6 +381,9 @@ class Finder:

     @staticmethod
     def print_eval_results(finder_eval_results: Dict):
+        if "predictions" in finder_eval_results.keys():
+            finder_eval_results = finder_eval_results["metrics"]
+
         print("\n___Retriever Metrics in Finder___")
         print(f"Retriever Recall            : {finder_eval_results['retriever_recall']:.3f}")
         print(f"Retriever Mean Avg Precision: {finder_eval_results['retriever_map']:.3f}")
diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index 36047460b..6c2bf7e90 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -291,7 +291,7 @@ class FARMReader(BaseReader):
         result = []
         for idx, group in enumerate(grouped_predictions):
             answers, max_no_ans_gap = self._extract_answers_of_predictions(group, top_k_per_question)
-            question = group[0]
+            question = group[0].question
             cur_label = labels[idx]
             result.append({
                 "question": question,
diff --git a/haystack/retriever/base.py b/haystack/retriever/base.py
index e73492f8f..21e97b449 100644
--- a/haystack/retriever/base.py
+++ b/haystack/retriever/base.py
@@ -45,7 +45,8 @@ class BaseRetriever(ABC):
         doc_index: str = "eval_document",
         label_origin: str = "gold_label",
         top_k: int = 10,
-        open_domain: bool = False
+        open_domain: bool = False,
+        return_preds: bool = False,
     ) -> dict:
         """
         Performs evaluation on the Retriever.
@@ -65,6 +66,8 @@ class BaseRetriever(ABC):
                             contained in the retrieved docs (common approach in open-domain QA).
                             If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
                             are within ids explicitly stated in the labels.
+        :param return_preds: Whether to include predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
""" # Extract all questions for evaluation @@ -86,11 +89,15 @@ class BaseRetriever(ABC): deduplicated_doc_ids = list(set([str(x) for x in label.multiple_document_ids])) question_label_dict[label.question] = deduplicated_doc_ids + predictions = [] + # Option 1: Open-domain evaluation by checking if the answer string is in the retrieved docs logger.info("Performing eval queries...") if open_domain: for question, gold_answers in tqdm(question_label_dict.items()): retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index) + if return_preds: + predictions.append({"question": question, "retrieved_docs": retrieved_docs}) # check if correct doc in retrieved docs for doc_idx, doc in enumerate(retrieved_docs): for gold_answer in gold_answers: @@ -102,6 +109,8 @@ class BaseRetriever(ABC): else: for question, gold_ids in tqdm(question_label_dict.items()): retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index) + if return_preds: + predictions.append({"question": question, "retrieved_docs": retrieved_docs}) # check if correct doc in retrieved docs for doc_idx, doc in enumerate(retrieved_docs): for gold_id in gold_ids: @@ -117,4 +126,15 @@ class BaseRetriever(ABC): logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in" f" the top-{top_k} candidate passages selected by the retriever.")) - return {"recall": recall, "map": mean_avg_precision, "retrieve_time": self.retrieve_time, "n_questions": number_of_questions, "top_k": top_k} \ No newline at end of file + metrics = { + "recall": recall, + "map": mean_avg_precision, + "retrieve_time": self.retrieve_time, + "n_questions": number_of_questions, + "top_k": top_k + } + + if return_preds: + return {"metrics": metrics, "predictions": predictions} + else: + return metrics \ No newline at end of file