mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-20 06:28:39 +00:00
Match answer sorting in QuestionAnsweringHead
with FARMReader
(#2414)
* match no_answer confidence * Update Documentation & Code Style * test added * Update Documentation & Code Style * fix tests * Update Documentation & Code Style * apply penalties of scores to confidences too Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
4bf470286b
commit
25475a68c7
@ -430,7 +430,7 @@ or use the Reader's device by default.
|
||||
#### eval
|
||||
|
||||
```python
|
||||
def eval(document_store: BaseDocumentStore, device: Optional[Union[str, torch.device]] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold-label", calibrate_conf_scores: bool = False)
|
||||
def eval(document_store: BaseDocumentStore, device: Optional[Union[str, torch.device]] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold-label", calibrate_conf_scores: bool = False, use_no_answer_legacy_confidence=False)
|
||||
```
|
||||
|
||||
Performs evaluation on evaluation documents in the DocumentStore.
|
||||
@ -450,6 +450,8 @@ or use the Reader's device by default.
|
||||
- `doc_index`: Index/Table name where documents that are used for evaluation are stored
|
||||
- `label_origin`: Field name where the gold labels are stored
|
||||
- `calibrate_conf_scores`: Whether to calibrate the temperature for temperature scaling of the confidence scores
|
||||
- `use_no_answer_legacy_confidence`: Whether to use the legacy confidence definition for no_answer: difference between the best overall answer confidence and the no_answer gap confidence.
|
||||
Otherwise we use the no_answer score normalized to a range of [0,1] by an expit function (default).
|
||||
|
||||
<a id="farm.FARMReader.calibrate_confidence_scores"></a>
|
||||
|
||||
|
@ -33,7 +33,12 @@ class Evaluator:
|
||||
self.report = report
|
||||
|
||||
def eval(
|
||||
self, model: AdaptiveModel, return_preds_and_labels: bool = False, calibrate_conf_scores: bool = False
|
||||
self,
|
||||
model: AdaptiveModel,
|
||||
return_preds_and_labels: bool = False,
|
||||
calibrate_conf_scores: bool = False,
|
||||
use_confidence_scores_for_ranking=True,
|
||||
use_no_answer_legacy_confidence=False,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Performs evaluation on a given model.
|
||||
@ -41,9 +46,14 @@ class Evaluator:
|
||||
:param model: The model on which to perform evaluation
|
||||
:param return_preds_and_labels: Whether to add preds and labels in the returned dicts of the
|
||||
:param calibrate_conf_scores: Whether to calibrate the temperature for temperature scaling of the confidence scores
|
||||
:param use_confidence_scores_for_ranking: Whether to sort answers by confidence score (normalized between 0 and 1)(default) or by standard score (unbounded).
|
||||
:param use_no_answer_legacy_confidence: Whether to use the legacy confidence definition for no_answer: difference between the best overall answer confidence and the no_answer gap confidence.
|
||||
Otherwise we use the no_answer score normalized to a range of [0,1] by an expit function (default).
|
||||
:return: all_results: A list of dictionaries, one for each prediction head. Each dictionary contains the metrics
|
||||
and reports generated during evaluation.
|
||||
"""
|
||||
model.prediction_heads[0].use_confidence_scores_for_ranking = use_confidence_scores_for_ranking
|
||||
model.prediction_heads[0].use_no_answer_legacy_confidence = use_no_answer_legacy_confidence
|
||||
model.eval()
|
||||
|
||||
# init empty lists per prediction head
|
||||
|
@ -10,6 +10,7 @@ from torch import nn
|
||||
from torch import optim
|
||||
from torch.nn import CrossEntropyLoss, NLLLoss
|
||||
from transformers import AutoModelForQuestionAnswering
|
||||
from scipy.special import expit
|
||||
|
||||
from haystack.modeling.data_handler.samples import SampleBasket
|
||||
from haystack.modeling.model.predictions import QACandidate, QAPred
|
||||
@ -234,7 +235,8 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
n_best_per_sample: Optional[int] = None,
|
||||
duplicate_filtering: int = -1,
|
||||
temperature_for_confidence: float = 1.0,
|
||||
use_confidence_scores_for_ranking: bool = False,
|
||||
use_confidence_scores_for_ranking: bool = True,
|
||||
use_no_answer_legacy_confidence: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -250,7 +252,9 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
:param duplicate_filtering: Answers are filtered based on their position. Both start and end position of the answers are considered.
|
||||
The higher the value, answers that are more apart are filtered out. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
|
||||
:param temperature_for_confidence: The divisor that is used to scale logits to calibrate confidence scores
|
||||
:param use_confidence_scores_for_ranking: Whether to sort answers by confidence score (normalized between 0 and 1) or by standard score (unbounded)(default).
|
||||
:param use_confidence_scores_for_ranking: Whether to sort answers by confidence score (normalized between 0 and 1)(default) or by standard score (unbounded).
|
||||
:param use_no_answer_legacy_confidence: Whether to use the legacy confidence definition for no_answer: difference between the best overall answer confidence and the no_answer gap confidence.
|
||||
Otherwise we use the no_answer score normalized to a range of [0,1] by an expit function (default).
|
||||
"""
|
||||
super(QuestionAnsweringHead, self).__init__()
|
||||
if len(kwargs) > 0:
|
||||
@ -279,6 +283,7 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
self.generate_config()
|
||||
self.temperature_for_confidence = nn.Parameter(torch.ones(1) * temperature_for_confidence)
|
||||
self.use_confidence_scores_for_ranking = use_confidence_scores_for_ranking
|
||||
self.use_no_answer_legacy_confidence = use_no_answer_legacy_confidence
|
||||
|
||||
@classmethod
|
||||
def load(cls, pretrained_model_name_or_path: Union[str, Path], revision: Optional[str] = None, **kwargs): # type: ignore
|
||||
@ -520,7 +525,11 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
if self.duplicate_filtering > -1 and (start_idx in start_idx_candidates or end_idx in end_idx_candidates):
|
||||
continue
|
||||
score = start_end_matrix[start_idx, end_idx].item()
|
||||
confidence = (start_matrix_softmax_start[start_idx].item() + end_matrix_softmax_end[end_idx].item()) / 2
|
||||
confidence = (
|
||||
(start_matrix_softmax_start[start_idx].item() + end_matrix_softmax_end[end_idx].item()) / 2
|
||||
if score > -500
|
||||
else np.exp(score / 10) # disqualify answers according to scores in logits_to_preds()
|
||||
)
|
||||
top_candidates.append(
|
||||
QACandidate(
|
||||
offset_answer_start=start_idx,
|
||||
@ -795,7 +804,9 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
aggregation_level="document",
|
||||
passage_id=None,
|
||||
n_passages_in_doc=n_samples,
|
||||
confidence=best_overall_positive_confidence - no_ans_gap_confidence,
|
||||
confidence=best_overall_positive_confidence - no_ans_gap_confidence
|
||||
if self.use_no_answer_legacy_confidence
|
||||
else float(expit(np.asarray(best_overall_positive_score - no_ans_gap) / 8)),
|
||||
)
|
||||
|
||||
# Add no answer to positive answers, sort the order and return the n_best
|
||||
|
@ -143,14 +143,9 @@ class FARMReader(BaseReader):
|
||||
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
|
||||
self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
|
||||
self.inferencer.model.prediction_heads[0].n_best = top_k_per_candidate + 1 # including possible no_answer
|
||||
try:
|
||||
self.inferencer.model.prediction_heads[0].n_best_per_sample = top_k_per_sample
|
||||
except:
|
||||
logger.warning("Could not set `top_k_per_sample` in FARM. Please update FARM version.")
|
||||
try:
|
||||
self.inferencer.model.prediction_heads[0].duplicate_filtering = duplicate_filtering
|
||||
except:
|
||||
logger.warning("Could not set `duplicate_filtering` in FARM. Please update FARM version.")
|
||||
self.inferencer.model.prediction_heads[0].use_confidence_scores_for_ranking = use_confidence_scores
|
||||
self.max_seq_len = max_seq_len
|
||||
self.progress_bar = progress_bar
|
||||
self.use_confidence_scores = use_confidence_scores
|
||||
@ -846,6 +841,7 @@ class FARMReader(BaseReader):
|
||||
doc_index: str = "eval_document",
|
||||
label_origin: str = "gold-label",
|
||||
calibrate_conf_scores: bool = False,
|
||||
use_no_answer_legacy_confidence=False,
|
||||
):
|
||||
"""
|
||||
Performs evaluation on evaluation documents in the DocumentStore.
|
||||
@ -862,6 +858,8 @@ class FARMReader(BaseReader):
|
||||
:param doc_index: Index/Table name where documents that are used for evaluation are stored
|
||||
:param label_origin: Field name where the gold labels are stored
|
||||
:param calibrate_conf_scores: Whether to calibrate the temperature for temperature scaling of the confidence scores
|
||||
:param use_no_answer_legacy_confidence: Whether to use the legacy confidence definition for no_answer: difference between the best overall answer confidence and the no_answer gap confidence.
|
||||
Otherwise we use the no_answer score normalized to a range of [0,1] by an expit function (default).
|
||||
"""
|
||||
if device is None:
|
||||
device = self.devices[0]
|
||||
@ -968,7 +966,12 @@ class FARMReader(BaseReader):
|
||||
|
||||
evaluator = Evaluator(data_loader=data_loader, tasks=self.inferencer.processor.tasks, device=device)
|
||||
|
||||
eval_results = evaluator.eval(self.inferencer.model, calibrate_conf_scores=calibrate_conf_scores)
|
||||
eval_results = evaluator.eval(
|
||||
self.inferencer.model,
|
||||
calibrate_conf_scores=calibrate_conf_scores,
|
||||
use_confidence_scores_for_ranking=self.use_confidence_scores,
|
||||
use_no_answer_legacy_confidence=use_no_answer_legacy_confidence,
|
||||
)
|
||||
toc = perf_counter()
|
||||
reader_time = toc - tic
|
||||
results = {
|
||||
|
@ -128,7 +128,8 @@ def test_add_eval_data(document_store, batch_size):
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus1"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_eval_reader(reader, document_store: BaseDocumentStore):
|
||||
@pytest.mark.parametrize("use_confidence_scores", [True, False])
|
||||
def test_eval_reader(reader, document_store: BaseDocumentStore, use_confidence_scores):
|
||||
# add eval data (SQUAD format)
|
||||
document_store.add_eval_data(
|
||||
filename=SAMPLES_PATH / "squad" / "tiny.json",
|
||||
@ -136,6 +137,9 @@ def test_eval_reader(reader, document_store: BaseDocumentStore):
|
||||
label_index="haystack_test_feedback",
|
||||
)
|
||||
assert document_store.get_document_count(index="haystack_test_eval_document") == 2
|
||||
|
||||
reader.use_confidence_scores = use_confidence_scores
|
||||
|
||||
# eval reader
|
||||
reader_eval_results = reader.eval(
|
||||
document_store=document_store,
|
||||
@ -143,8 +147,13 @@ def test_eval_reader(reader, document_store: BaseDocumentStore):
|
||||
doc_index="haystack_test_eval_document",
|
||||
device="cpu",
|
||||
)
|
||||
assert reader_eval_results["f1"] > 66.65
|
||||
assert reader_eval_results["f1"] < 66.67
|
||||
|
||||
if use_confidence_scores:
|
||||
assert reader_eval_results["f1"] == 50
|
||||
assert reader_eval_results["EM"] == 50
|
||||
assert reader_eval_results["top_n_accuracy"] == 100.0
|
||||
else:
|
||||
assert 66.67 > reader_eval_results["f1"] > 66.65
|
||||
assert reader_eval_results["EM"] == 50
|
||||
assert reader_eval_results["top_n_accuracy"] == 100.0
|
||||
|
||||
|
@ -64,17 +64,8 @@ def test_span_inference_result_ranking_by_confidence(bert_base_squad2, caplog=No
|
||||
questions=Question("Who counted the game among the best ever made?", uid="best_id_ever"),
|
||||
)
|
||||
]
|
||||
result = bert_base_squad2.inference_from_objects(obj_input, return_json=False)[0]
|
||||
|
||||
# by default, result is sorted by score and not by confidence
|
||||
assert all(result.prediction[i].score >= result.prediction[i + 1].score for i in range(len(result.prediction) - 1))
|
||||
assert not all(
|
||||
result.prediction[i].confidence >= result.prediction[i + 1].confidence
|
||||
for i in range(len(result.prediction) - 1)
|
||||
)
|
||||
|
||||
# ranking can be adjusted so that result is sorted by confidence
|
||||
bert_base_squad2.model.prediction_heads[0].use_confidence_scores_for_ranking = True
|
||||
# by default, result is sorted by confidence and not by score
|
||||
result_ranked_by_confidence = bert_base_squad2.inference_from_objects(obj_input, return_json=False)[0]
|
||||
assert all(
|
||||
result_ranked_by_confidence.prediction[i].confidence >= result_ranked_by_confidence.prediction[i + 1].confidence
|
||||
@ -85,6 +76,18 @@ def test_span_inference_result_ranking_by_confidence(bert_base_squad2, caplog=No
|
||||
for i in range(len(result_ranked_by_confidence.prediction) - 1)
|
||||
)
|
||||
|
||||
# ranking can be adjusted so that result is sorted by score
|
||||
bert_base_squad2.model.prediction_heads[0].use_confidence_scores_for_ranking = False
|
||||
result_ranked_by_score = bert_base_squad2.inference_from_objects(obj_input, return_json=False)[0]
|
||||
assert all(
|
||||
result_ranked_by_score.prediction[i].score >= result_ranked_by_score.prediction[i + 1].score
|
||||
for i in range(len(result_ranked_by_score.prediction) - 1)
|
||||
)
|
||||
assert not all(
|
||||
result_ranked_by_score.prediction[i].confidence >= result_ranked_by_score.prediction[i + 1].confidence
|
||||
for i in range(len(result_ranked_by_score.prediction) - 1)
|
||||
)
|
||||
|
||||
|
||||
def test_inference_objs(span_inference_result, caplog=None):
|
||||
if caplog:
|
||||
@ -226,6 +229,7 @@ def test_no_duplicate_answer_filtering(bert_base_squad2):
|
||||
bert_base_squad2.model.prediction_heads[0].n_best = 5
|
||||
bert_base_squad2.model.prediction_heads[0].n_best_per_sample = 5
|
||||
bert_base_squad2.model.prediction_heads[0].duplicate_filtering = -1
|
||||
bert_base_squad2.model.prediction_heads[0].no_ans_boost = -100.0
|
||||
|
||||
result = bert_base_squad2.inference_from_dicts(dicts=qa_input)
|
||||
offset_answer_starts = []
|
||||
|
@ -1,6 +1,7 @@
|
||||
import math
|
||||
|
||||
import pytest
|
||||
from haystack.modeling.data_handler.inputs import QAInput, Question
|
||||
|
||||
from haystack.schema import Document, Answer
|
||||
from haystack.nodes.reader.base import BaseReader
|
||||
@ -169,3 +170,38 @@ def test_farm_reader_update_params(test_docs_xs):
|
||||
with pytest.raises(Exception):
|
||||
reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=99, doc_stride=128)
|
||||
reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_confidence_scores", [True, False])
|
||||
def test_farm_reader_uses_same_sorting_as_QAPredictionHead(use_confidence_scores):
|
||||
reader = FARMReader(
|
||||
model_name_or_path="deepset/roberta-base-squad2",
|
||||
use_gpu=False,
|
||||
num_processes=0,
|
||||
return_no_answer=True,
|
||||
use_confidence_scores=use_confidence_scores,
|
||||
)
|
||||
|
||||
text = """Beer is one of the oldest[1][2][3] and most widely consumed[4] alcoholic drinks in the world, and the third most popular drink overall after water and tea.[5] It is produced by the brewing and fermentation of starches, mainly derived from cereal grains—most commonly from malted barley, though wheat, maize (corn), rice, and oats are also used. During the brewing process, fermentation of the starch sugars in the wort produces ethanol and carbonation in the resulting beer.[6] Most modern beer is brewed with hops, which add bitterness and other flavours and act as a natural preservative and stabilizing agent. Other flavouring agents such as gruit, herbs, or fruits may be included or used instead of hops. In commercial brewing, the natural carbonation effect is often removed during processing and replaced with forced carbonation.[7]
|
||||
Some of humanity's earliest known writings refer to the production and distribution of beer: the Code of Hammurabi included laws regulating beer and beer parlours,[8] and "The Hymn to Ninkasi", a prayer to the Mesopotamian goddess of beer, served as both a prayer and as a method of remembering the recipe for beer in a culture with few literate people.[9][10]
|
||||
Beer is distributed in bottles and cans and is also commonly available on draught, particularly in pubs and bars. The brewing industry is a global business, consisting of several dominant multinational companies and many thousands of smaller producers ranging from brewpubs to regional breweries. The strength of modern beer is usually around 4% to 6% alcohol by volume (ABV), although it may vary between 0.5% and 20%, with some breweries creating examples of 40% ABV and above.[11]
|
||||
Beer forms part of the culture of many nations and is associated with social traditions such as beer festivals, as well as a rich pub culture involving activities like pub crawling, pub quizzes and pub games.
|
||||
When beer is distilled, the resulting liquor is a form of whisky.[12]
|
||||
"""
|
||||
|
||||
docs = [Document(text)]
|
||||
query = "What is the third most popular drink?"
|
||||
|
||||
reader_predictions = reader.predict(query=query, documents=docs, top_k=5)
|
||||
|
||||
farm_input = [QAInput(doc_text=d.content, questions=Question(query)) for d in docs]
|
||||
inferencer_predictions = reader.inferencer.inference_from_objects(farm_input, return_json=False)
|
||||
|
||||
for answer, qa_cand in zip(reader_predictions["answers"], inferencer_predictions[0].prediction):
|
||||
assert answer.answer == ("" if qa_cand.answer_type == "no_answer" else qa_cand.answer)
|
||||
assert answer.offsets_in_document[0].start == qa_cand.offset_answer_start
|
||||
assert answer.offsets_in_document[0].end == qa_cand.offset_answer_end
|
||||
if use_confidence_scores:
|
||||
assert answer.score == qa_cand.confidence
|
||||
else:
|
||||
assert answer.score == qa_cand.score
|
||||
|
Loading…
x
Reference in New Issue
Block a user