Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-26 16:46:58 +00:00)
Make returning predictions in evaluation possible (#524)

* Make returning preds in evaluation possible
* Add automated check if eval dict contains predictions

parent 4fa5d9c3eb
commit 18d315d61a
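The change threads a new return_preds flag through the evaluation entry points: with the default False they keep returning the flat metrics dictionary, while True wraps the metrics and the raw predictions together under the keys "metrics" and "predictions". A minimal usage sketch, assuming an already initialised Finder whose document store holds the evaluation documents and gold labels; the `finder` variable and the surrounding call pattern are illustrative, not part of this commit:

# Sketch only: `finder` is assumed to be a haystack Finder built from a reader and a
# retriever, with evaluation documents and labels already written to the default eval indices.
eval_result = finder.eval(
    top_k_retriever=10,
    top_k_reader=10,
    return_preds=True,   # new flag introduced by this commit
)

# With return_preds=True the result carries both keys ...
metrics = eval_result["metrics"]
predictions = eval_result["predictions"]
print(f"Retriever recall: {metrics['retriever_recall']:.3f}, predictions collected: {len(predictions)}")

# ... while the default (return_preds=False) still returns the plain metrics dict.
flat_metrics = finder.eval(top_k_retriever=10, top_k_reader=10)

The changes to the Finder evaluation methods: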
@@ -121,6 +121,7 @@ class Finder:
         label_origin: str = "gold_label",
         top_k_retriever: int = 10,
         top_k_reader: int = 10,
+        return_preds: bool = False,
     ):
         """
         Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
@@ -165,6 +166,9 @@ class Finder:
         :type top_k_retriever: int
         :param top_k_reader: How many answers to return per question
         :type top_k_reader: int
+        :param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
+        :type return_preds: bool
         """

         if not self.reader or not self.retriever:
@@ -205,6 +209,7 @@ class Finder:
         previous_return_no_answers = self.reader.return_no_answers
         self.reader.return_no_answers = True

+        predictions = []
         # extract answers
         reader_start_time = time.time()
         for q_idx, question_docs in enumerate(questions_with_docs):
@@ -217,6 +222,8 @@ class Finder:
             single_reader_start = time.time()
             predicted_answers = self.reader.predict(question_string, docs, top_k=top_k_reader)  # type: ignore
             read_times.append(time.time() - single_reader_start)
+            if return_preds:
+                predictions.append(predicted_answers)
             counts = eval_counts_reader(question, predicted_answers, counts)

         counts["number_of_has_answer"] = counts["correct_retrievals"] - counts["number_of_no_answer"]
@@ -240,6 +247,9 @@ class Finder:
         eval_results["avg_reader_time"] = mean(read_times)
         eval_results["total_finder_time"] = finder_total_time

-        return eval_results
+        if return_preds:
+            return {"metrics": eval_results, "predictions": predictions}
+        else:
+            return eval_results

     def eval_batch(
@@ -249,7 +259,8 @@ class Finder:
         label_origin: str = "gold_label",
         top_k_retriever: int = 10,
         top_k_reader: int = 10,
-        batch_size: int = 50
+        batch_size: int = 50,
+        return_preds: bool = False,
     ):
         """
         Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
@@ -296,10 +307,13 @@ class Finder:
         :type top_k_reader: int
         :param batch_size: Number of samples per batch computed at once
         :type batch_size: int
+        :param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
+        :type return_preds: bool
         """

         if not self.reader or not self.retriever:
-            raise Exception("Finder needs to have a reader and retriever for the evalutaion.")
+            raise Exception("Finder needs to have a reader and retriever for the evaluation.")

         counts = defaultdict(float)  # type: Dict[str, float]
         finder_start_time = time.time()
@@ -344,6 +358,9 @@ class Finder:
         logger.info(f"{number_of_questions - correct_retrievals} questions could not be answered due to the retriever.")
         logger.info(f"{correct_retrievals - counts['correct_readings_topk']} questions could not be answered due to the reader.")

-        return results
+        if return_preds:
+            return {"metrics": results, "predictions": predictions}
+        else:
+            return results


@@ -364,6 +381,9 @@ class Finder:

     @staticmethod
     def print_eval_results(finder_eval_results: Dict):
+        if "predictions" in finder_eval_results.keys():
+            finder_eval_results = finder_eval_results["metrics"]
+
         print("\n___Retriever Metrics in Finder___")
         print(f"Retriever Recall            : {finder_eval_results['retriever_recall']:.3f}")
         print(f"Retriever Mean Avg Precision: {finder_eval_results['retriever_map']:.3f}")
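Because print_eval_results now checks for a "predictions" key and unwraps the metrics itself, the same helper accepts either return shape. A short illustrative call, under the same assumptions about an existing `finder` as in the sketch above:

# The helper unwraps ["metrics"] on its own when predictions are present,
# so both of these print the same retriever/reader metrics.
wrapped = finder.eval(top_k_retriever=10, top_k_reader=10, return_preds=True)
finder.print_eval_results(wrapped)

flat = finder.eval(top_k_retriever=10, top_k_reader=10)
finder.print_eval_results(flat)

The remaining hunks adjust FARMReader and BaseRetriever: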
@@ -291,7 +291,7 @@ class FARMReader(BaseReader):
         result = []
         for idx, group in enumerate(grouped_predictions):
             answers, max_no_ans_gap = self._extract_answers_of_predictions(group, top_k_per_question)
-            question = group[0]
+            question = group[0].question
             cur_label = labels[idx]
             result.append({
                 "question": question,
@@ -45,7 +45,8 @@ class BaseRetriever(ABC):
         doc_index: str = "eval_document",
         label_origin: str = "gold_label",
         top_k: int = 10,
-        open_domain: bool = False
+        open_domain: bool = False,
+        return_preds: bool = False,
     ) -> dict:
         """
         Performs evaluation on the Retriever.
@@ -65,6 +66,8 @@ class BaseRetriever(ABC):
                             contained in the retrieved docs (common approach in open-domain QA).
                             If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
                             are within ids explicitly stated in the labels.
+        :param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
+                             contains the keys "predictions" and "metrics".
         """

         # Extract all questions for evaluation
@@ -86,11 +89,15 @@ class BaseRetriever(ABC):
             deduplicated_doc_ids = list(set([str(x) for x in label.multiple_document_ids]))
             question_label_dict[label.question] = deduplicated_doc_ids

+        predictions = []
+
         # Option 1: Open-domain evaluation by checking if the answer string is in the retrieved docs
         logger.info("Performing eval queries...")
         if open_domain:
             for question, gold_answers in tqdm(question_label_dict.items()):
                 retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index)
+                if return_preds:
+                    predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_answer in gold_answers:
@@ -102,6 +109,8 @@ class BaseRetriever(ABC):
         else:
             for question, gold_ids in tqdm(question_label_dict.items()):
                 retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index)
+                if return_preds:
+                    predictions.append({"question": question, "retrieved_docs": retrieved_docs})
                 # check if correct doc in retrieved docs
                 for doc_idx, doc in enumerate(retrieved_docs):
                     for gold_id in gold_ids:
@@ -117,4 +126,15 @@ class BaseRetriever(ABC):
         logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in"
                      f" the top-{top_k} candidate passages selected by the retriever."))

-        return {"recall": recall, "map": mean_avg_precision, "retrieve_time": self.retrieve_time, "n_questions": number_of_questions, "top_k": top_k}
+        metrics = {
+            "recall": recall,
+            "map": mean_avg_precision,
+            "retrieve_time": self.retrieve_time,
+            "n_questions": number_of_questions,
+            "top_k": top_k
+        }
+
+        if return_preds:
+            return {"metrics": metrics, "predictions": predictions}
+        else:
+            return metrics
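On the retriever side the same flag switches between the flat metrics dictionary and the wrapped form, whose prediction entries pair each question with its retrieved documents. A hedged sketch, assuming `retriever` is a concrete BaseRetriever subclass (for example an Elasticsearch-backed one) whose document store already contains the evaluation documents and labels:

# return_preds=False (default): flat dict with "recall", "map", "retrieve_time",
# "n_questions" and "top_k", exactly as before this commit.
flat = retriever.eval(top_k=10, open_domain=True)

# return_preds=True: the same metrics plus the per-question retrieval results
# collected in the loops above.
wrapped = retriever.eval(top_k=10, open_domain=True, return_preds=True)
print(wrapped["metrics"]["recall"])
for pred in wrapped["predictions"][:3]:
    # each entry has the shape {"question": ..., "retrieved_docs": [...]}
    print(pred["question"], len(pred["retrieved_docs"]))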