Mirror of https://github.com/deepset-ai/haystack.git
Add new QA eval metric: Semantic Answer Similarity (SAS) (#1338)
* init
* Add type annotation
* Add test case, fix mypy
* Add german model to docstring

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in: parent ba071cc052, commit 07bd3c50ea

haystack/eval.py (284 changed lines)
@@ -1,5 +1,9 @@
 from typing import List, Tuple, Dict, Any, Optional
 import logging
+from transformers import AutoConfig
+from sentence_transformers import SentenceTransformer, CrossEncoder
+from sklearn.metrics.pairwise import cosine_similarity
+import numpy as np
 
 from haystack import MultiLabel, Label
@@ -148,18 +152,34 @@ class EvalAnswers:
     open vs closed domain eval (https://haystack.deepset.ai/docs/latest/tutorial5md).
     """
 
-    def __init__(self, skip_incorrect_retrieval: bool=True, open_domain: bool=True, debug: bool=False):
+    def __init__(self,
+                 skip_incorrect_retrieval: bool = True,
+                 open_domain: bool = True,
+                 sas_model: str = None,
+                 debug: bool = False,
+                 ):
         """
         :param skip_incorrect_retrieval: When set to True, this eval will ignore the cases where the retriever returned no correct documents
         :param open_domain: When True, extracted answers are evaluated purely on string similarity rather than the position of the extracted answer
+        :param sas_model: Name or path of "Semantic Answer Similarity (SAS) model". When set, the model will be used to calculate similarity between predictions and labels and generate the SAS metric.
+                          The SAS metric correlates better with human judgement of correct answers as it does not rely on string overlaps.
+                          Example: Prediction = "30%", Label = "thirty percent", EM and F1 would be overly pessimistic with both being 0, while SAS paints a more realistic picture.
+                          Models:
+                          - You can use Bi Encoders (sentence transformers) or cross encoders trained on Semantic Textual Similarity (STS) data.
+                            Not all cross encoders can be used because of different return types.
+                            If you use custom cross encoders please make sure they work with sentence_transformers.CrossEncoder class
+                          - Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+                          - Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large"
+                          - Large model for German only: "deepset/gbert-large-sts"
         :param debug: When True, a record of each sample and its evaluation will be stored in EvalAnswers.log
         """
         self.outgoing_edges = 1
-        self.init_counts()
         self.log: List = []
         self.debug = debug
         self.skip_incorrect_retrieval = skip_incorrect_retrieval
         self.open_domain = open_domain
+        self.sas_model = sas_model
+        self.init_counts()
 
     def init_counts(self):
         self.query_count = 0
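A minimal usage sketch for the new parameter (not part of the diff), assuming haystack.eval is importable and the listed models can be downloaded on first use; the model names are the ones recommended in the docstring above:

    from haystack.eval import EvalAnswers

    # Bi-Encoder, the suggested default for multiple languages
    eval_answers = EvalAnswers(sas_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

    # Cross encoder variant; must be compatible with the sentence_transformers.CrossEncoder class
    eval_answers_cross = EvalAnswers(sas_model="cross-encoder/stsb-roberta-large", debug=True)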
@@ -176,6 +196,11 @@ class EvalAnswers:
         self.top_k_em = 0.0
         self.top_1_f1 = 0.0
         self.top_k_f1 = 0.0
+        if self.sas_model is not None:
+            self.top_1_sas_sum = 0
+            self.top_k_sas_sum = 0
+            self.top_1_sas = 0.0
+            self.top_k_sas = 0.0
 
     def run(self, labels, answers, **kwargs):
         """Run this node on one sample and its labels"""
@@ -201,12 +226,27 @@ class EvalAnswers:
             self.has_answer_count += 1
             predictions = [p for p in predictions if p["answer"]]
             top_1_em, top_1_f1, top_k_em, top_k_f1 = self.evaluate_extraction(multi_labels, predictions)
 
+            # Compute Semantic Answer Similarity if model is supplied
+            if self.sas_model is not None:
+                # sas works on batches, so we pack the labels into a list of lists, and unpack the return values as well
+                gold_labels = [multi_labels.multiple_answers]
+                predictions_list = [[p["answer"] for p in predictions]]
+                top_1_sas, top_k_sas = semantic_answer_similarity(
+                    predictions=predictions_list,
+                    gold_labels=gold_labels,
+                    sas_model_name_or_path=self.sas_model)
+                self.top_1_sas_sum += top_1_sas[0]
+                self.top_k_sas_sum += top_k_sas[0]
+
             if self.debug:
                 self.log.append({"predictions": predictions,
                                  "gold_labels": multi_labels,
                                  "top_k_f1": top_k_f1,
                                  "top_k_em": top_k_em
                                  })
+                if self.sas_model:
+                    self.log[-1].update({"top_k_sas":top_k_sas})
 
             self.top_1_em_count += top_1_em
             self.top_1_f1_sum += top_1_f1
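As the comment in run() notes, semantic_answer_similarity (defined later in this diff) works on batches: each question contributes one inner list of predictions and one inner list of gold answers, and each returned list holds one score per question. A sketch of the same packing outside the node, with illustrative values:

    gold_labels = [["Berlin", "the city of Berlin"]]   # one question, two acceptable answers
    predictions_list = [["Berlin", "Hamburg"]]         # one question, two predicted answers
    top_1_sas, top_k_sas = semantic_answer_similarity(
        predictions=predictions_list,
        gold_labels=gold_labels,
        sas_model_name_or_path="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
    top_1 = top_1_sas[0]   # best similarity of the first (most likely) prediction to any gold label
    top_k = top_k_sas[0]   # best similarity over all predictions for this question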
@@ -233,6 +273,9 @@ class EvalAnswers:
         self.top_k_em = self.top_k_em_count / self.has_answer_count
         self.top_1_f1 = self.top_1_f1_sum / self.has_answer_count
         self.top_k_f1 = self.top_k_f1_sum / self.has_answer_count
+        if self.sas_model is not None:
+            self.top_1_sas = self.top_1_sas_sum / self.has_answer_count
+            self.top_k_sas = self.top_k_sas_sum / self.has_answer_count
 
     def update_no_answer_metrics(self):
         self.top_1_no_answer = self.top_1_no_answer_count / self.no_answer_count
@@ -248,6 +291,9 @@ class EvalAnswers:
         print(f"top k EM: {self.top_k_em:.4f}")
         print(f"top 1 F1: {self.top_1_f1:.4f}")
         print(f"top k F1: {self.top_k_f1:.4f}")
+        if self.sas_model is not None:
+            print(f"top 1 SAS: {self.top_1_sas:.4f}")
+            print(f"top k SAS: {self.top_k_sas:.4f}")
         if self.no_answer_count:
             print()
             print(f"no_answer queries: {self.no_answer_count}")
@@ -266,11 +312,17 @@ class EvalAnswers:
         print(f"top k EM: {pipeline_top_k_em:.4f}")
         print(f"top 1 F1: {pipeline_top_1_f1:.4f}")
         print(f"top k F1: {pipeline_top_k_f1:.4f}")
+        if self.sas_model is not None:
+            pipeline_top_1_sas = (self.top_1_sas_sum + self.top_1_no_answer_count) / self.query_count
+            pipeline_top_k_sas = (self.top_k_sas_sum + self.no_answer_count) / self.query_count
+            print(f"top 1 SAS: {pipeline_top_1_sas:.4f}")
+            print(f"top k SAS: {pipeline_top_k_sas:.4f}")
         if self.no_answer_count:
             print(
                 "(top k results are likely inflated since the Reader always returns a no_answer prediction in its top k)"
             )
 
 
 def get_label(labels, node_id):
     if type(labels) in [Label, MultiLabel]:
         ret = labels
@@ -279,6 +331,7 @@ def get_label(labels, node_id):
         ret = labels[node_id]
     return ret
 
+
 def calculate_em_str_multi(gold_labels, prediction):
     for gold_label in gold_labels:
         result = calculate_em_str(gold_label, prediction)
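In the pipeline-level aggregation above, answerable queries contribute their accumulated SAS scores, while no_answer queries contribute full credit when the Reader correctly returned no answer, mirroring the EM/F1 aggregation. A worked example under assumed counts (all values hypothetical):

    # hypothetical run: 8 answerable queries whose top-1 SAS scores sum to 6.4,
    # plus 2 no_answer queries that the Reader answered correctly at rank 1
    top_1_sas_sum = 6.4
    top_1_no_answer_count = 2
    query_count = 10
    pipeline_top_1_sas = (top_1_sas_sum + top_1_no_answer_count) / query_count  # = 0.84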
@@ -295,176 +348,72 @@ def calculate_f1_str_multi(gold_labels, prediction):
     return max(results)
 
 
-def calculate_reader_metrics(metric_counts: Dict[str, float], correct_retrievals: int):
-    number_of_has_answer = correct_retrievals - metric_counts["number_of_no_answer"]
-
-    metrics = {
-        "reader_top1_accuracy" : metric_counts["correct_readings_top1"] / correct_retrievals,
-        "reader_top1_accuracy_has_answer" : metric_counts["correct_readings_top1_has_answer"] / number_of_has_answer,
-        "reader_topk_accuracy" : metric_counts["correct_readings_topk"] / correct_retrievals,
-        "reader_topk_accuracy_has_answer" : metric_counts["correct_readings_topk_has_answer"] / number_of_has_answer,
-        "reader_top1_em" : metric_counts["exact_matches_top1"] / correct_retrievals,
-        "reader_top1_em_has_answer" : metric_counts["exact_matches_top1_has_answer"] / number_of_has_answer,
-        "reader_topk_em" : metric_counts["exact_matches_topk"] / correct_retrievals,
-        "reader_topk_em_has_answer" : metric_counts["exact_matches_topk_has_answer"] / number_of_has_answer,
-        "reader_top1_f1" : metric_counts["summed_f1_top1"] / correct_retrievals,
-        "reader_top1_f1_has_answer" : metric_counts["summed_f1_top1_has_answer"] / number_of_has_answer,
-        "reader_topk_f1" : metric_counts["summed_f1_topk"] / correct_retrievals,
-        "reader_topk_f1_has_answer" : metric_counts["summed_f1_topk_has_answer"] / number_of_has_answer,
-    }
-
-    if metric_counts["number_of_no_answer"]:
-        metrics["reader_top1_no_answer_accuracy"] = metric_counts["correct_no_answers_top1"] / metric_counts[
-            "number_of_no_answer"]
-        metrics["reader_topk_no_answer_accuracy"] = metric_counts["correct_no_answers_topk"] / metric_counts[
-            "number_of_no_answer"]
-    else:
-        metrics["reader_top1_no_answer_accuracy"] = None  # type: ignore
-        metrics["reader_topk_no_answer_accuracy"] = None  # type: ignore
-
-    return metrics
-
-
-def calculate_average_precision_and_reciprocal_rank(questions_with_docs: List[dict]):
-    questions_with_correct_doc = []
-    summed_avg_precision_retriever = 0.0
-    summed_reciprocal_rank_retriever = 0.0
-
-    for question in questions_with_docs:
-        number_relevant_docs = len(set(question["question"].multiple_document_ids))
-        found_relevant_doc = False
-        relevant_docs_found = 0
-        current_avg_precision = 0.0
-        for doc_idx, doc in enumerate(question["docs"]):
-            # check if correct doc among retrieved docs
-            if doc.id in question["question"].multiple_document_ids:
-                if not found_relevant_doc:
-                    summed_reciprocal_rank_retriever += 1 / (doc_idx + 1)
-                relevant_docs_found += 1
-                found_relevant_doc = True
-                current_avg_precision += relevant_docs_found / (doc_idx + 1)
-                if relevant_docs_found == number_relevant_docs:
-                    break
-        if found_relevant_doc:
-            all_relevant_docs = len(set(question["question"].multiple_document_ids))
-            summed_avg_precision_retriever += current_avg_precision / all_relevant_docs
-
-        if found_relevant_doc:
-            questions_with_correct_doc.append({
-                "question": question["question"],
-                "docs": question["docs"]
-            })
-
-    return questions_with_correct_doc, summed_avg_precision_retriever, summed_reciprocal_rank_retriever
-
-
-def eval_counts_reader(question: MultiLabel, predicted_answers: Dict[str, Any], metric_counts: Dict[str, float]):
-    # Calculates evaluation metrics for one question and adds results to counter.
-    # check if question is answerable
-    if not question.no_answer:
-        found_answer = False
-        found_em = False
-        best_f1 = 0
-        for answer_idx, answer in enumerate(predicted_answers["answers"]):
-            if answer["document_id"] in question.multiple_document_ids:
-                gold_spans = [{"offset_start": question.multiple_offset_start_in_docs[i],
-                               "offset_end": question.multiple_offset_start_in_docs[i] + len(question.multiple_answers[i]),
-                               "doc_id": question.multiple_document_ids[i]} for i in range(len(question.multiple_answers))]  # type: ignore
-                predicted_span = {"offset_start": answer["offset_start_in_doc"],
-                                  "offset_end": answer["offset_end_in_doc"],
-                                  "doc_id": answer["document_id"]}
-                best_f1_in_gold_spans = 0
-                for gold_span in gold_spans:
-                    if gold_span["doc_id"] == predicted_span["doc_id"]:
-                        # check if overlap between gold answer and predicted answer
-                        if not found_answer:
-                            metric_counts, found_answer = _count_overlap(gold_span, predicted_span, metric_counts, answer_idx)  # type: ignore
-
-                        # check for exact match
-                        if not found_em:
-                            metric_counts, found_em = _count_exact_match(gold_span, predicted_span, metric_counts, answer_idx)  # type: ignore
-
-                        # calculate f1
-                        current_f1 = _calculate_f1(gold_span, predicted_span)  # type: ignore
-                        if current_f1 > best_f1_in_gold_spans:
-                            best_f1_in_gold_spans = current_f1
-                # top-1 f1
-                if answer_idx == 0:
-                    metric_counts["summed_f1_top1"] += best_f1_in_gold_spans
-                    metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans
-                if best_f1_in_gold_spans > best_f1:
-                    best_f1 = best_f1_in_gold_spans
-
-            if found_em:
-                break
-        # top-k answers: use best f1-score
-        metric_counts["summed_f1_topk"] += best_f1
-        metric_counts["summed_f1_topk_has_answer"] += best_f1
-
-    # question not answerable
-    else:
-        metric_counts["number_of_no_answer"] += 1
-        metric_counts = _count_no_answer(predicted_answers["answers"], metric_counts)
-
-    return metric_counts
-
-
-def eval_counts_reader_batch(pred: Dict[str, Any], metric_counts: Dict[str, float]):
-    # Calculates evaluation metrics for one question and adds results to counter.
-
-    # check if question is answerable
-    if not pred["label"].no_answer:
-        found_answer = False
-        found_em = False
-        best_f1 = 0
-        for answer_idx, answer in enumerate(pred["answers"]):
-            # check if correct document:
-            if answer["document_id"] in pred["label"].multiple_document_ids:
-                gold_spans = [{"offset_start": pred["label"].multiple_offset_start_in_docs[i],
-                               "offset_end": pred["label"].multiple_offset_start_in_docs[i] + len(pred["label"].multiple_answers[i]),
-                               "doc_id": pred["label"].multiple_document_ids[i]}
-                              for i in range(len(pred["label"].multiple_answers))]  # type: ignore
-                predicted_span = {"offset_start": answer["offset_start_in_doc"],
-                                  "offset_end": answer["offset_end_in_doc"],
-                                  "doc_id": answer["document_id"]}
-
-                best_f1_in_gold_spans = 0
-                for gold_span in gold_spans:
-                    if gold_span["doc_id"] == predicted_span["doc_id"]:
-                        # check if overlap between gold answer and predicted answer
-                        if not found_answer:
-                            metric_counts, found_answer = _count_overlap(
-                                gold_span, predicted_span, metric_counts, answer_idx
-                            )
-                        # check for exact match
-                        if not found_em:
-                            metric_counts, found_em = _count_exact_match(
-                                gold_span, predicted_span, metric_counts, answer_idx
-                            )
-                        # calculate f1
-                        current_f1 = _calculate_f1(gold_span, predicted_span)
-                        if current_f1 > best_f1_in_gold_spans:
-                            best_f1_in_gold_spans = current_f1
-                # top-1 f1
-                if answer_idx == 0:
-                    metric_counts["summed_f1_top1"] += best_f1_in_gold_spans
-                    metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans
-                if best_f1_in_gold_spans > best_f1:
-                    best_f1 = best_f1_in_gold_spans
-
-            if found_em:
-                break
-
-        # top-k answers: use best f1-score
-        metric_counts["summed_f1_topk"] += best_f1
-        metric_counts["summed_f1_topk_has_answer"] += best_f1
-
-    # question not answerable
-    else:
-        metric_counts["number_of_no_answer"] += 1
-        metric_counts = _count_no_answer(pred["answers"], metric_counts)
-
-    return metric_counts
+def semantic_answer_similarity(predictions: List[List[str]],
+                               gold_labels: List[List[str]],
+                               sas_model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+                               ) -> Tuple[List[float],List[float]]:
+    """
+    Computes Transformer-based similarity of predicted answer to gold labels to derive a more meaningful metric than EM or F1.
+    Returns per QA pair a) the similarity of the most likely prediction (top 1) to all available gold labels
+                        b) the highest similarity of all predictions to gold labels
+
+    :param predictions: Predicted answers as list of multiple preds per question
+    :param gold_labels: Labels as list of multiple possible answers per question
+    :param sas_model_name_or_path: SentenceTransformers semantic textual similarity model, should be path or string
+                                   pointing to downloadable models.
+
+    :return top_1_sas, top_k_sas
+    """
+    assert len(predictions) == len(gold_labels)
+
+    config = AutoConfig.from_pretrained(sas_model_name_or_path)
+    cross_encoder_used = False
+    if config.architectures is not None:
+        cross_encoder_used = any([arch.endswith('ForSequenceClassification') for arch in config.architectures])
+
+    # Compute similarities
+    top_1_sas = []
+    top_k_sas = []
+
+    # Based on Modelstring we can load either Bi-Encoders or Cross Encoders.
+    # Similarity computation changes for both approaches
+    if cross_encoder_used:
+        model = CrossEncoder(sas_model_name_or_path)
+        for preds, labels in zip (predictions,gold_labels):
+            # TODO add efficient batch mode: put all texts and labels into grid and extract scores afterwards
+            grid = []
+            for p in preds:
+                for l in labels:
+                    grid.append((p,l))
+            scores = model.predict(grid)
+            top_1_sas.append(np.max(scores[:len(labels)]))
+            top_k_sas.append(np.max(scores))
+    else:
+        # For Bi-encoders we can flatten predictions and labels into one list
+        model = SentenceTransformer(sas_model_name_or_path)
+        lengths: List[Tuple[int,int]] = []
+        all_texts: List[str] = []
+        for p, l in zip(predictions, gold_labels):  # type: ignore
+            # TODO potentially exclude (near) exact matches from computations
+            all_texts.extend(p)
+            all_texts.extend(l)
+            lengths.append((len(p), len(l)))
+        # then compute embeddings
+        embeddings = model.encode(all_texts)
+
+        # then select which embeddings will be used for similarity computations
+        current_position = 0
+        for i, (len_p, len_l) in enumerate(lengths):
+            pred_embeddings = embeddings[current_position:current_position + len_p, :]
+            current_position += len_p
+            label_embeddings = embeddings[current_position:current_position + len_l, :]
+            current_position += len_l
+            sims = cosine_similarity(pred_embeddings, label_embeddings)
+            top_1_sas.append(np.max(sims[0, :]))
+            top_k_sas.append(np.max(sims))
+
+    return top_1_sas, top_k_sas
 
 
 def _count_overlap(
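Taken together, the new function picks its scoring path from the model config: architectures ending in "ForSequenceClassification" are treated as cross encoders and scored pairwise over a (prediction, label) grid, everything else is embedded as a bi-encoder and compared via cosine similarity. A standalone sketch of the docstring's motivating example (the score depends on the model, so the printed value is illustrative only):

    from haystack.eval import semantic_answer_similarity

    top_1_sas, top_k_sas = semantic_answer_similarity(
        predictions=[["30%"]],
        gold_labels=[["thirty percent"]],
        sas_model_name_or_path="cross-encoder/stsb-roberta-large")
    print(top_1_sas[0])  # high similarity, whereas EM and F1 would both be 0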
@@ -554,3 +503,4 @@ def _count_no_answer(answers: List[dict], metric_counts: Dict[str, float]):
             break
 
     return metric_counts
+
test/test_eval.py

@@ -106,7 +106,9 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever):
     labels = document_store.get_all_labels_aggregated(index="haystack_test_feedback")
 
     eval_retriever = EvalDocuments()
-    eval_reader = EvalAnswers()
+    eval_reader = EvalAnswers(sas_model="sentence-transformers/paraphrase-MiniLM-L3-v2",debug=True)
+    eval_reader_cross = EvalAnswers(sas_model="cross-encoder/stsb-TinyBERT-L-4",debug=True)
+    eval_reader_vanila = EvalAnswers()
 
     assert document_store.get_document_count(index="haystack_test_eval_document") == 2
     p = Pipeline()
@@ -114,6 +116,8 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever):
     p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"])
     p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"])
     p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"])
+    p.add_node(component=eval_reader_cross, name="EvalAnswers_cross", inputs=["QAReader"])
+    p.add_node(component=eval_reader_vanila, name="EvalAnswers_vanilla", inputs=["QAReader"])
     for l in labels:
         res = p.run(
             query=l.question,
@@ -125,6 +129,9 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever):
     assert eval_retriever.recall == 1.0
     assert round(eval_reader.top_k_f1, 4) == 0.8333
     assert eval_reader.top_k_em == 0.5
+    assert round(eval_reader.top_k_sas, 3) == 0.800
+    assert round(eval_reader_cross.top_k_sas, 3) == 0.671
+    assert eval_reader.top_k_em == eval_reader_vanila.top_k_em
 
 
 @pytest.mark.elasticsearch
 def test_eval_data_split_word(document_store):
tutorials/Tutorial5_Evaluation.py

@@ -99,7 +99,7 @@ def tutorial5_evaluation():
 
     # Here we initialize the nodes that perform evaluation
     eval_retriever = EvalDocuments()
-    eval_reader = EvalAnswers()
+    eval_reader = EvalAnswers(sas_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
 
 
     ## Evaluate Retriever on its own in closed domain fashion