Added support for unanswerable questions in TransformersReader (#258)

* Added support for unanswerable questions in TransformersReader

Co-authored-by: Antonio Lanza <anotniolanza1996@gmail.com>
This commit is contained in:
antoniolanza1996 2020-07-23 10:45:58 +02:00 committed by GitHub
parent f0d901a374
commit b55de6f70a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,6 +23,7 @@ class TransformersReader(BaseReader):
context_window_size: int = 30,
use_gpu: int = 0,
n_best_per_passage: int = 2,
no_answer: bool = True
):
"""
Load a QA model from Transformers.
@ -39,11 +40,16 @@ class TransformersReader(BaseReader):
The context usually helps users to understand if the answer really makes sense.
:param use_gpu: < 0 -> use cpu
>= 0 -> ordinal of the gpu to use
:param n_best_per_passage: num of best answers to take into account for each passage
:param no_answer: True -> Hugging Face model may return an "impossible"/"empty" answer (i.e. when the question is unanswerable)
False -> otherwise
"""
self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
self.context_window_size = context_window_size
self.n_best_per_passage = n_best_per_passage
# TODO: add a param to modify the bias for no_answer
self.no_answer = no_answer
# TODO: context_window_size behaviour differs from its behaviour in FARMReader
def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
@ -76,25 +82,24 @@ class TransformersReader(BaseReader):
answers = []
for doc in documents:
query = {"context": doc.text, "question": question}
predictions = self.model(query, topk=self.n_best_per_passage)
predictions = self.model(query, topk=self.n_best_per_passage,handle_impossible_answer=self.no_answer)
# for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
if type(predictions) == dict:
predictions = [predictions]
# assemble and format all answers
for pred in predictions:
if pred["answer"]:
context_start = max(0, pred["start"] - self.context_window_size)
context_end = min(len(doc.text), pred["end"] + self.context_window_size)
answers.append({
"answer": pred["answer"],
"context": doc.text[context_start:context_end],
"offset_start": pred["start"],
"offset_end": pred["end"],
"probability": pred["score"],
"score": None,
"document_id": doc.id,
"meta": doc.meta
})
context_start = max(0, pred["start"] - self.context_window_size)
context_end = min(len(doc.text), pred["end"] + self.context_window_size)
answers.append({
"answer": pred["answer"],
"context": doc.text[context_start:context_end],
"offset_start": pred["start"],
"offset_end": pred["end"],
"probability": pred["score"],
"score": None,
"document_id": doc.id,
"meta": doc.meta
})
# sort answers by their `probability` and select top-k
answers = sorted(