mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-03 13:23:38 +00:00
Added support for unanswerable questions in TransformersReader (#258)
* Added support for unanswerable questions in TransformersReader

Co-authored-by: Antonio Lanza <anotniolanza1996@gmail.com>
This commit is contained in:
parent
f0d901a374
commit
b55de6f70a
@ -23,6 +23,7 @@ class TransformersReader(BaseReader):
|
|||||||
context_window_size: int = 30,
|
context_window_size: int = 30,
|
||||||
use_gpu: int = 0,
|
use_gpu: int = 0,
|
||||||
n_best_per_passage: int = 2,
|
n_best_per_passage: int = 2,
|
||||||
|
no_answer: bool = True
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Load a QA model from Transformers.
|
Load a QA model from Transformers.
|
||||||
@ -39,11 +40,16 @@ class TransformersReader(BaseReader):
|
|||||||
The context usually helps users to understand if the answer really makes sense.
|
The context usually helps users to understand if the answer really makes sense.
|
||||||
:param use_gpu: < 0 -> use cpu
|
:param use_gpu: < 0 -> use cpu
|
||||||
>= 0 -> ordinal of the gpu to use
|
>= 0 -> ordinal of the gpu to use
|
||||||
|
:param n_best_per_passage: number of best answers to take into account for each passage
|
||||||
|
:param no_answer: True -> the Hugging Face model may return an "impossible"/"empty" answer (i.e. when the question cannot be answered from the given passage)
|
||||||
|
False -> "impossible"/"empty" answers are not accepted from the model
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
|
self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
|
||||||
self.context_window_size = context_window_size
|
self.context_window_size = context_window_size
|
||||||
self.n_best_per_passage = n_best_per_passage
|
self.n_best_per_passage = n_best_per_passage
|
||||||
#TODO param to modify bias for no_answer
|
self.no_answer = no_answer
|
||||||
|
|
||||||
# TODO context_window_size behavior differs from the behavior in FARMReader
|
# TODO context_window_size behavior differs from the behavior in FARMReader
|
||||||
|
|
||||||
def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
|
def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
|
||||||
@ -76,25 +82,24 @@ class TransformersReader(BaseReader):
|
|||||||
answers = []
|
answers = []
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
query = {"context": doc.text, "question": question}
|
query = {"context": doc.text, "question": question}
|
||||||
predictions = self.model(query, topk=self.n_best_per_passage)
|
predictions = self.model(query, topk=self.n_best_per_passage,handle_impossible_answer=self.no_answer)
|
||||||
# for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
|
# for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
|
||||||
if type(predictions) == dict:
|
if type(predictions) == dict:
|
||||||
predictions = [predictions]
|
predictions = [predictions]
|
||||||
# assemble and format all answers
|
# assemble and format all answers
|
||||||
for pred in predictions:
|
for pred in predictions:
|
||||||
if pred["answer"]:
|
context_start = max(0, pred["start"] - self.context_window_size)
|
||||||
context_start = max(0, pred["start"] - self.context_window_size)
|
context_end = min(len(doc.text), pred["end"] + self.context_window_size)
|
||||||
context_end = min(len(doc.text), pred["end"] + self.context_window_size)
|
answers.append({
|
||||||
answers.append({
|
"answer": pred["answer"],
|
||||||
"answer": pred["answer"],
|
"context": doc.text[context_start:context_end],
|
||||||
"context": doc.text[context_start:context_end],
|
"offset_start": pred["start"],
|
||||||
"offset_start": pred["start"],
|
"offset_end": pred["end"],
|
||||||
"offset_end": pred["end"],
|
"probability": pred["score"],
|
||||||
"probability": pred["score"],
|
"score": None,
|
||||||
"score": None,
|
"document_id": doc.id,
|
||||||
"document_id": doc.id,
|
"meta": doc.meta
|
||||||
"meta": doc.meta
|
})
|
||||||
})
|
|
||||||
|
|
||||||
# sort answers by their `probability` and select top-k
|
# sort answers by their `probability` and select top-k
|
||||||
answers = sorted(
|
answers = sorted(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user