Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-09-25 08:04:49 +00:00
Make returning predictions in evaluation possible (#524)
* Make returning preds in evaluation possible
* Make returning preds in evaluation possible
* Add automated check if eval dict contains predictions
parent 4fa5d9c3eb
commit 18d315d61a
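Taken together, the diff below adds an opt-in return_preds flag to Finder.eval, Finder.eval_batch and BaseRetriever.eval: with return_preds=True the call returns a dict with the keys "metrics" and "predictions" instead of the bare metrics dict. A hedged usage sketch follows (not part of the commit); `finder` is an assumed, already-constructed haystack Finder with an indexed evaluation set, and any other required arguments are elided.

# Sketch only: `finder` and the argument values are assumptions for illustration,
# not part of this diff; required index arguments are elided.
eval_output = finder.eval(
    top_k_retriever=10,
    top_k_reader=5,
    return_preds=True,  # new flag introduced by this commit
)

# With return_preds=True the result is wrapped, per the docstrings added below:
metrics = eval_output["metrics"]
predictions = eval_output["predictions"]

# With return_preds=False (the default) the bare metrics dict is returned, as before.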
@@ -121,6 +121,7 @@ class Finder:
         label_origin: str = "gold_label",
         top_k_retriever: int = 10,
         top_k_reader: int = 10,
+        return_preds: bool = False,
     ):
         """
         Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
@@ -165,6 +166,9 @@ class Finder:
         :type top_k_retriever: int
         :param top_k_reader: How many answers to return per question
         :type top_k_reader: int
+        :param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
+        :type return_preds: bool
         """

         if not self.reader or not self.retriever:
@@ -205,6 +209,7 @@ class Finder:
         previous_return_no_answers = self.reader.return_no_answers
         self.reader.return_no_answers = True

+        predictions = []
         # extract answers
         reader_start_time = time.time()
         for q_idx, question_docs in enumerate(questions_with_docs):
@@ -215,8 +220,10 @@ class Finder:
             question_string = question.question
             docs = question_docs["docs"]  # type: ignore
             single_reader_start = time.time()
             predicted_answers = self.reader.predict(question_string, docs, top_k=top_k_reader)  # type: ignore
             read_times.append(time.time() - single_reader_start)
+            if return_preds:
+                predictions.append(predicted_answers)
             counts = eval_counts_reader(question, predicted_answers, counts)

         counts["number_of_has_answer"] = counts["correct_retrievals"] - counts["number_of_no_answer"]
@@ -240,7 +247,10 @@ class Finder:
         eval_results["avg_reader_time"] = mean(read_times)
         eval_results["total_finder_time"] = finder_total_time

-        return eval_results
+        if return_preds:
+            return {"metrics": eval_results, "predictions": predictions}
+        else:
+            return eval_results

     def eval_batch(
         self,
@@ -249,7 +259,8 @@ class Finder:
         label_origin: str = "gold_label",
         top_k_retriever: int = 10,
         top_k_reader: int = 10,
-        batch_size: int = 50
+        batch_size: int = 50,
+        return_preds: bool = False,
     ):
         """
         Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
@@ -296,10 +307,13 @@ class Finder:
         :type top_k_reader: int
         :param batch_size: Number of samples per batch computed at once
         :type batch_size: int
+        :param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
+        :type return_preds: bool
         """

         if not self.reader or not self.retriever:
-            raise Exception("Finder needs to have a reader and retriever for the evalutaion.")
+            raise Exception("Finder needs to have a reader and retriever for the evaluation.")

         counts = defaultdict(float)  # type: Dict[str, float]
         finder_start_time = time.time()
@@ -344,7 +358,10 @@ class Finder:
         logger.info(f"{number_of_questions - correct_retrievals} questions could not be answered due to the retriever.")
         logger.info(f"{correct_retrievals - counts['correct_readings_topk']} questions could not be answered due to the reader.")

-        return results
+        if return_preds:
+            return {"metrics": results, "predictions": predictions}
+        else:
+            return results


     def _retrieve_docs(self, questions: List[MultiLabel], top_k: int, doc_index: str):
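The eval_batch changes mirror eval: the same return_preds flag and the same wrapped return shape. A hedged sketch of consuming the wrapped output; `finder` is again an assumed, already-constructed Finder, required index arguments are elided, and the per-entry shape of "predictions" depends on the reader.

# Sketch only, under the assumptions stated above.
batch_output = finder.eval_batch(batch_size=50, return_preds=True)

print(batch_output["metrics"])
for prediction in batch_output["predictions"]:
    print(prediction)  # predictions collected during evaluation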
@@ -364,6 +381,9 @@ class Finder:

     @staticmethod
     def print_eval_results(finder_eval_results: Dict):
+        if "predictions" in finder_eval_results.keys():
+            finder_eval_results = finder_eval_results["metrics"]
+
         print("\n___Retriever Metrics in Finder___")
         print(f"Retriever Recall            : {finder_eval_results['retriever_recall']:.3f}")
         print(f"Retriever Mean Avg Precision: {finder_eval_results['retriever_map']:.3f}")
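print_eval_results now unwraps the metrics when it is handed the return_preds=True output, so both return shapes can be printed with the same call. A minimal standalone sketch of that dispatch in plain Python; the helper name unwrap_finder_eval_results is made up for illustration.

def unwrap_finder_eval_results(finder_eval_results: dict) -> dict:
    # Same guard as the lines added above: if the dict carries predictions,
    # the metrics live under the "metrics" key; otherwise it already is the metrics dict.
    if "predictions" in finder_eval_results:
        return finder_eval_results["metrics"]
    return finder_eval_results

# Both shapes resolve to the same metrics dict:
assert unwrap_finder_eval_results({"retriever_recall": 0.9}) == {"retriever_recall": 0.9}
assert unwrap_finder_eval_results({"metrics": {"retriever_recall": 0.9}, "predictions": []}) == {"retriever_recall": 0.9}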
@@ -291,7 +291,7 @@ class FARMReader(BaseReader):
         result = []
         for idx, group in enumerate(grouped_predictions):
             answers, max_no_ans_gap = self._extract_answers_of_predictions(group, top_k_per_question)
-            question = group[0]
+            question = group[0].question
             cur_label = labels[idx]
             result.append({
                 "question": question,
@@ -45,7 +45,8 @@ class BaseRetriever(ABC):
         doc_index: str = "eval_document",
         label_origin: str = "gold_label",
         top_k: int = 10,
-        open_domain: bool = False
+        open_domain: bool = False,
+        return_preds: bool = False,
     ) -> dict:
         """
         Performs evaluation on the Retriever.
@@ -65,6 +66,8 @@ class BaseRetriever(ABC):
                              contained in the retrieved docs (common approach in open-domain QA).
                              If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
                              are within ids explicitly stated in the labels.
+        :param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
         """

         # Extract all questions for evaluation
@@ -86,11 +89,15 @@ class BaseRetriever(ABC):
                 deduplicated_doc_ids = list(set([str(x) for x in label.multiple_document_ids]))
                 question_label_dict[label.question] = deduplicated_doc_ids

+        predictions = []
+
         # Option 1: Open-domain evaluation by checking if the answer string is in the retrieved docs
         logger.info("Performing eval queries...")
         if open_domain:
             for question, gold_answers in tqdm(question_label_dict.items()):
                 retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index)
+                if return_preds:
+                    predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_answer in gold_answers:
@@ -102,6 +109,8 @@ class BaseRetriever(ABC):
         else:
             for question, gold_ids in tqdm(question_label_dict.items()):
                 retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index)
+                if return_preds:
+                    predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_id in gold_ids:
@@ -117,4 +126,15 @@ class BaseRetriever(ABC):
         logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in"
                      f" the top-{top_k} candidate passages selected by the retriever."))

-        return {"recall": recall, "map": mean_avg_precision, "retrieve_time": self.retrieve_time, "n_questions": number_of_questions, "top_k": top_k}
+        metrics = {
+            "recall": recall,
+            "map": mean_avg_precision,
+            "retrieve_time": self.retrieve_time,
+            "n_questions": number_of_questions,
+            "top_k": top_k
+        }
+
+        if return_preds:
+            return {"metrics": metrics, "predictions": predictions}
+        else:
+            return metrics
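For the retriever-only path, the code above shows what the predictions list holds: one dict per question with the keys "question" and "retrieved_docs". A hedged usage sketch; `retriever` is an assumed, already-initialised haystack retriever with a populated "eval_document" index (the default doc_index shown above), and any other required arguments are elided.

# Sketch only, under the assumptions stated above.
retriever_output = retriever.eval(top_k=10, return_preds=True)

print(retriever_output["metrics"]["recall"], retriever_output["metrics"]["map"])
for entry in retriever_output["predictions"]:
    # Per the diff above, each entry looks like {"question": ..., "retrieved_docs": [...]}
    print(entry["question"], len(entry["retrieved_docs"]))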