Add "no answer" aggregation to Transformersreader (#259)

* Add no answer aggregation

* Change to covariant type annotation

* Remove n_best_per_passage from TransformersReader
Timo Moeller 2020-08-06 17:32:55 +02:00 committed by GitHub
parent 89dcfed619
commit d9e8b522a1
4 changed files with 74 additions and 46 deletions

View File

@@ -1,5 +1,7 @@
+import numpy as np
+from scipy.special import expit
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Sequence
 from haystack.database.base import Document
@@ -9,3 +11,30 @@ class BaseReader(ABC):
     @abstractmethod
     def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
         pass
+
+    @staticmethod
+    def _calc_no_answer(no_ans_gaps: Sequence[float], best_score_answer: float):
+        # "no answer" scores and positive answer scores are difficult to compare, because
+        # + a positive answer score is related to one specific document
+        # - a "no answer" score is related to all input documents
+        # Thus we compute the "no answer" score relative to the best possible answer and adjust it by
+        # the most significant difference between scores.
+        # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
+        # no_ans_gaps is a list of this most significant difference per document
+        no_ans_gaps = np.array(no_ans_gaps)
+        max_no_ans_gap = np.max(no_ans_gaps)
+        # case: all passages have "no answer" as their top score
+        if np.sum(no_ans_gaps < 0) == len(no_ans_gaps):  # type: ignore
+            no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative, so this increases the best positive score
+        else:  # case: at least one passage predicts an answer (positive no_ans_gap)
+            no_ans_score = best_score_answer - max_no_ans_gap
+        no_ans_prediction = {"answer": None,
+                             "score": no_ans_score,
+                             "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
+                             "context": None,
+                             "offset_start": 0,
+                             "offset_end": 0,
+                             "document_id": None,
+                             "meta": None}
+        return no_ans_prediction, max_no_ans_gap
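
For illustration (not part of this commit): the aggregation in `_calc_no_answer` boils down to taking the largest per-document gap, subtracting it from the best answer score found across all documents, and squashing the result into a pseudo probability. A minimal standalone sketch with toy numbers, using only numpy and scipy:

# Standalone sketch (illustration only, not from the commit): mirrors the math in BaseReader._calc_no_answer
import numpy as np
from scipy.special import expit

# toy values: one "most significant difference" (no_ans_gap) per input document,
# and the best positive answer score seen across all documents
no_ans_gaps = np.array([-2.0, 1.5, -0.5])
best_score_answer = 10.0

max_no_ans_gap = np.max(no_ans_gaps)                  # 1.5
no_ans_score = best_score_answer - max_no_ans_gap     # 8.5 -> now comparable to per-document answer scores
no_ans_probability = float(expit(no_ans_score / 8))   # pseudo probability in (0, 1), as in the commit

print(max_no_ans_gap, no_ans_score, round(no_ans_probability, 2))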

View File

@@ -40,7 +40,7 @@ class FARMReader(BaseReader):
         context_window_size: int = 150,
         batch_size: int = 50,
         use_gpu: bool = True,
-        no_ans_boost: Optional[int] = None,
+        no_ans_boost: Optional[float] = None,
         top_k_per_candidate: int = 3,
         top_k_per_sample: int = 1,
         num_processes: Optional[int] = None,
@@ -446,32 +446,6 @@ class FARMReader(BaseReader):
             return False

-    @staticmethod
-    def _calc_no_answer(no_ans_gaps: List[float], best_score_answer: float):
-        # "no answer" scores and positive answer scores are difficult to compare, because
-        # + a positive answer score is related to one specific document
-        # - a "no answer" score is related to all input documents
-        # Thus we compute the "no answer" score relative to the best possible answer and adjust it by
-        # the most significant difference between scores.
-        # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
-        # no_ans_gap coming from FARM means how much no_ans_boost would have to change to switch the prediction
-        no_ans_gaps = np.array(no_ans_gaps)
-        max_no_ans_gap = np.max(no_ans_gaps)
-        # case: all passages have "no answer" as their top score
-        if np.sum(no_ans_gaps < 0) == len(no_ans_gaps):  # type: ignore
-            no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative, so this increases the best positive score
-        else:  # case: at least one passage predicts an answer (positive no_ans_gap)
-            no_ans_score = best_score_answer - max_no_ans_gap
-        no_ans_prediction = {"answer": None,
-                             "score": no_ans_score,
-                             "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
-                             "context": None,
-                             "offset_start": 0,
-                             "offset_end": 0,
-                             "document_id": None}
-        return no_ans_prediction, max_no_ans_gap
-
     def predict_on_texts(self, question: str, texts: List[str], top_k: Optional[int] = None):
         documents = []
         for text in texts:

View File

@@ -22,7 +22,7 @@ class TransformersReader(BaseReader):
         tokenizer: str = "distilbert-base-uncased",
         context_window_size: int = 30,
         use_gpu: int = 0,
-        n_best_per_passage: int = 2,
+        top_k_per_candidate: int = 4,
         no_answer: bool = True
     ):
         """
@@ -40,14 +40,17 @@ class TransformersReader(BaseReader):
                                     The context usually helps users to understand if the answer really makes sense.
        :param use_gpu: < 0 -> use cpu
                        >= 0 -> ordinal of the gpu to use
-       :param n_best_per_passage: num of best answers to take into account for each passage
+       :param top_k_per_candidate: How many answers to extract for each candidate doc that is coming from the retriever (might be a long text).
+                                   Note: - This is not the number of "final answers" you will receive
+                                           (see `top_k` in TransformersReader.predict() or Finder.get_answers() for that)
+                                         - Can include a no_answer prediction in the sorted list of predictions
        :param no_answer: True -> Hugging Face model could return an "impossible"/"empty" answer (i.e. when there is an unanswerable question)
                          False -> otherwise
        """
        self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
        self.context_window_size = context_window_size
-       self.n_best_per_passage = n_best_per_passage
+       self.top_k_per_candidate = top_k_per_candidate
        self.no_answer = no_answer
        # TODO context_window_size behaviour different from behavior in FARMReader
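
As a usage illustration (not part of the commit): the renamed parameter is passed at construction time. The import path below is an assumption based on the Haystack package layout of that time (it is not shown in this diff), and the model/tokenizer names are copied from the no_answer_reader test fixture at the bottom of this diff.

# Usage sketch (illustration only). The import path is an assumption, not shown in this diff;
# model/tokenizer names are taken from the test fixture below.
from haystack.reader.transformers import TransformersReader  # assumed module path

reader = TransformersReader(
    model="deepset/roberta-base-squad2",
    tokenizer="deepset/roberta-base-squad2",
    use_gpu=-1,               # < 0 -> CPU, >= 0 -> GPU ordinal (see docstring above)
    top_k_per_candidate=4,    # answers extracted per candidate document (previously n_best_per_passage)
    no_answer=True,           # allow the model to return an "impossible"/"empty" answer
)

# Final answers still come from predict(); `top_k` there controls how many are returned:
# result = reader.predict(question="Who is the father of Arya Stark?", documents=docs, top_k=3)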
@@ -80,14 +83,24 @@
         """
         # get top-answers for each candidate passage
         answers = []
+        no_ans_gaps = []
+        best_overall_score = 0
         for doc in documents:
             query = {"context": doc.text, "question": question}
-            predictions = self.model(query, topk=self.n_best_per_passage,handle_impossible_answer=self.no_answer)
+            predictions = self.model(query, topk=self.top_k_per_candidate, handle_impossible_answer=self.no_answer)
             # for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
             if type(predictions) == dict:
                 predictions = [predictions]
             # assemble and format all answers
+            best_doc_score = 0
+            # because we cannot ensure a "no answer" prediction coming back from transformers we initialize it here with 0
+            no_ans_doc_score = 0
+            # TODO add no answer bias on haystack side after getting "no answer" scores from transformers
             for pred in predictions:
+                if pred["answer"]:
+                    if pred["score"] > best_doc_score:
+                        best_doc_score = pred["score"]
                 context_start = max(0, pred["start"] - self.context_window_size)
                 context_end = min(len(doc.text), pred["end"] + self.context_window_size)
                 answers.append({
@@ -100,7 +113,19 @@
                     "document_id": doc.id,
                     "meta": doc.meta
                 })
+                else:
+                    no_ans_doc_score = pred["score"]
+            if best_doc_score > best_overall_score:
+                best_overall_score = best_doc_score
+            no_ans_gaps.append(no_ans_doc_score - best_doc_score)
+
+        # Calculate the score for predicting "no answer", relative to our best positive answer score
+        no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps, best_overall_score)
+        if self.no_answer:
+            answers.append(no_ans_prediction)
+
         # sort answers by their `probability` and select top-k
         answers = sorted(
             answers, key=lambda k: k["probability"], reverse=True
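
To make the new bookkeeping concrete, here is a minimal standalone sketch (not part of the commit) that runs the same per-document loop over hand-made prediction dicts instead of real pipeline output; the resulting gaps and best score are exactly what gets handed to `_calc_no_answer`.

# Standalone sketch (illustration only): the per-document bookkeeping added to predict(),
# applied to hand-made prediction dicts instead of real transformers pipeline output.
docs_predictions = [
    # document 1: a positive answer wins, the empty answer scores low
    [{"answer": "Berlin", "score": 0.75}, {"answer": "", "score": 0.125}],
    # document 2: the empty ("no answer") prediction wins
    [{"answer": "", "score": 0.5}, {"answer": "Paris", "score": 0.25}],
]

no_ans_gaps = []
best_overall_score = 0
for predictions in docs_predictions:
    best_doc_score = 0
    no_ans_doc_score = 0                      # "no answer" may not come back, so default to 0
    for pred in predictions:
        if pred["answer"]:                    # positive answer span
            if pred["score"] > best_doc_score:
                best_doc_score = pred["score"]
        else:                                 # empty answer -> "no answer" score for this document
            no_ans_doc_score = pred["score"]
    if best_doc_score > best_overall_score:
        best_overall_score = best_doc_score
    no_ans_gaps.append(no_ans_doc_score - best_doc_score)

print(no_ans_gaps, best_overall_score)        # [-0.625, 0.25] 0.75 -> input to _calc_no_answer()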

View File

@@ -85,7 +85,7 @@ def no_answer_reader(request):
     if request.param == "transformers":
         return TransformersReader(model="deepset/roberta-base-squad2",
                                   tokenizer="deepset/roberta-base-squad2",
-                                  use_gpu=-1, n_best_per_passage=5)
+                                  use_gpu=-1, top_k_per_candidate=5)

 @pytest.fixture()