Add "no answer" aggregation to Transformersreader (#259)
* Add no answer aggregation * Change to covariant type annotation * Remove n_best_per_passage from transformersreader
This commit is contained in:
parent
89dcfed619
commit
d9e8b522a1
@@ -1,5 +1,7 @@
+import numpy as np
+from scipy.special import expit
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
 from haystack.database.base import Document
 
@@ -9,3 +11,30 @@ class BaseReader(ABC):
     @abstractmethod
     def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
         pass
+
+    @staticmethod
+    def _calc_no_answer(no_ans_gaps: Sequence[float], best_score_answer: float):
+        # "no answer" scores and positive answers scores are difficult to compare, because
+        # + a positive answer score is related to one specific document
+        # - a "no answer" score is related to all input documents
+        # Thus we compute the "no answer" score relative to the best possible answer and adjust it by
+        # the most significant difference between scores.
+        # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
+        # No_ans_gap is a list of this most significant difference per document
+        no_ans_gaps = np.array(no_ans_gaps)
+        max_no_ans_gap = np.max(no_ans_gaps)
+        # all passages "no answer" as top score
+        if (np.sum(no_ans_gaps < 0) == len(no_ans_gaps)):  # type: ignore
+            no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative, so it increases best pos score
+        else:  # case: at least one passage predicts an answer (positive no_ans_gap)
+            no_ans_score = best_score_answer - max_no_ans_gap
+
+        no_ans_prediction = {"answer": None,
+                             "score": no_ans_score,
+                             "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
+                             "context": None,
+                             "offset_start": 0,
+                             "offset_end": 0,
+                             "document_id": None,
+                             "meta": None,}
+        return no_ans_prediction, max_no_ans_gap
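For intuition, here is a small worked trace of the new `_calc_no_answer` (a sketch with invented scores; the gaps follow the `no_ans_doc_score - best_doc_score` convention used in the TransformersReader changes below):

import numpy as np
from scipy.special import expit

# Invented per-document gaps: "no answer" score minus best answer score per doc.
no_ans_gaps = np.array([-7.0, -2.0, 3.0])   # only the third doc favors "no answer"
best_score_answer = 10.0                    # best positive answer score over all docs

max_no_ans_gap = np.max(no_ans_gaps)               # 3.0
no_ans_score = best_score_answer - max_no_ans_gap  # 10.0 - 3.0 = 7.0
probability = float(expit(no_ans_score / 8))       # pseudo-probability, ~0.71

Note that both branches of the if/else return the identical expression `best_score_answer - max_no_ans_gap`; the condition only documents the two regimes (every gap negative vs. at least one positive gap).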
@@ -40,7 +40,7 @@ class FARMReader(BaseReader):
                  context_window_size: int = 150,
                  batch_size: int = 50,
                  use_gpu: bool = True,
-                 no_ans_boost: Optional[int] = None,
+                 no_ans_boost: Optional[float] = None,
                  top_k_per_candidate: int = 3,
                  top_k_per_sample: int = 1,
                  num_processes: Optional[int] = None,
@@ -446,32 +446,6 @@ class FARMReader(BaseReader):
         return False
 
 
-    @staticmethod
-    def _calc_no_answer(no_ans_gaps: List[float], best_score_answer: float):
-        # "no answer" scores and positive answers scores are difficult to compare, because
-        # + a positive answer score is related to one specific document
-        # - a "no answer" score is related to all input documents
-        # Thus we compute the "no answer" score relative to the best possible answer and adjust it by
-        # the most significant difference between scores.
-        # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
-        # No_ans_gap coming from FARM mean how much no_ans_boost should change to switch predictions
-        no_ans_gaps = np.array(no_ans_gaps)
-        max_no_ans_gap = np.max(no_ans_gaps)
-        # all passages "no answer" as top score
-        if (np.sum(no_ans_gaps < 0) == len(no_ans_gaps)):  # type: ignore
-            no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative, so it increases best pos score
-        else:  # case: at least one passage predicts an answer (positive no_ans_gap)
-            no_ans_score = best_score_answer - max_no_ans_gap
-
-        no_ans_prediction = {"answer": None,
-                             "score": no_ans_score,
-                             "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
-                             "context": None,
-                             "offset_start": 0,
-                             "offset_end": 0,
-                             "document_id": None}
-        return no_ans_prediction, max_no_ans_gap
-
     def predict_on_texts(self, question: str, texts: List[str], top_k: Optional[int] = None):
         documents = []
         for text in texts:
@@ -22,7 +22,7 @@ class TransformersReader(BaseReader):
                  tokenizer: str = "distilbert-base-uncased",
                  context_window_size: int = 30,
                  use_gpu: int = 0,
-                 n_best_per_passage: int = 2,
+                 top_k_per_candidate: int = 4,
                  no_answer: bool = True
                  ):
         """
@@ -40,14 +40,17 @@ class TransformersReader(BaseReader):
                                     The context usually helps users to understand if the answer really makes sense.
         :param use_gpu: < 0 -> use cpu
                         >= 0 -> ordinal of the gpu to use
-        :param n_best_per_passage: num of best answers to take into account for each passage
+        :param top_k_per_candidate: How many answers to extract for each candidate doc that is coming from the retriever (might be a long text).
+                                    Note: - This is not the number of "final answers" you will receive
+                                    (see `top_k` in TransformersReader.predict() or Finder.get_answers() for that)
+                                    - Can include no_answer in the sorted list of predictions
         :param no_answer: True -> Hugging Face model could return an "impossible"/"empty" answer (i.e. when there is an unanswerable question)
                           False -> otherwise
 
         """
         self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
         self.context_window_size = context_window_size
-        self.n_best_per_passage = n_best_per_passage
+        self.top_k_per_candidate = top_k_per_candidate
         self.no_answer = no_answer
 
         # TODO context_window_size behaviour different from behavior in FARMReader
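As a usage sketch of the renamed parameter (illustrative only: the import paths and the `Document` constructor fields are assumed from the haystack 0.x layout of this era, and model weights are downloaded on first use):

from haystack.database.base import Document
from haystack.reader.transformers import TransformersReader  # assumed module path

reader = TransformersReader(model="deepset/roberta-base-squad2",
                            tokenizer="deepset/roberta-base-squad2",
                            use_gpu=-1,             # < 0 -> CPU
                            top_k_per_candidate=4,  # answers extracted per candidate doc
                            no_answer=True)

docs = [Document(id="1", text="Berlin is the capital of Germany."),  # fields assumed
        Document(id="2", text="Paris is the capital of France.")]

# `top_k` caps the number of final answers returned overall;
# `top_k_per_candidate` only bounds how many are extracted per document.
result = reader.predict(question="What is the capital of Germany?", documents=docs, top_k=3)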
@@ -80,27 +83,49 @@ class TransformersReader(BaseReader):
         """
         # get top-answers for each candidate passage
         answers = []
+        no_ans_gaps = []
+        best_overall_score = 0
         for doc in documents:
             query = {"context": doc.text, "question": question}
-            predictions = self.model(query, topk=self.n_best_per_passage,handle_impossible_answer=self.no_answer)
+            predictions = self.model(query, topk=self.top_k_per_candidate, handle_impossible_answer=self.no_answer)
             # for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
             if type(predictions) == dict:
                 predictions = [predictions]
             # assemble and format all answers
-            for pred in predictions:
-                context_start = max(0, pred["start"] - self.context_window_size)
-                context_end = min(len(doc.text), pred["end"] + self.context_window_size)
-                answers.append({
-                    "answer": pred["answer"],
-                    "context": doc.text[context_start:context_end],
-                    "offset_start": pred["start"],
-                    "offset_end": pred["end"],
-                    "probability": pred["score"],
-                    "score": None,
-                    "document_id": doc.id,
-                    "meta": doc.meta
-                })
+            best_doc_score = 0
+            # because we cannot ensure a "no answer" prediction coming back from transformers we initialize it here with 0
+            no_ans_doc_score = 0
+            # TODO add no answer bias on haystack side after getting "no answer" scores from transformers
+            for pred in predictions:
+                if pred["answer"]:
+                    if pred["score"] > best_doc_score:
+                        best_doc_score = pred["score"]
+                    context_start = max(0, pred["start"] - self.context_window_size)
+                    context_end = min(len(doc.text), pred["end"] + self.context_window_size)
+                    answers.append({
+                        "answer": pred["answer"],
+                        "context": doc.text[context_start:context_end],
+                        "offset_start": pred["start"],
+                        "offset_end": pred["end"],
+                        "probability": pred["score"],
+                        "score": None,
+                        "document_id": doc.id,
+                        "meta": doc.meta
+                    })
+                else:
+                    no_ans_doc_score = pred["score"]
+
+            if best_doc_score > best_overall_score:
+                best_overall_score = best_doc_score
+
+            no_ans_gaps.append(no_ans_doc_score - best_doc_score)
+
+        # Calculate the score for predicting "no answer", relative to our best positive answer score
+        no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps, best_overall_score)
+
+        if self.no_answer:
+            answers.append(no_ans_prediction)
+
         # sort answers by their `probability` and select top-k
         answers = sorted(
             answers, key=lambda k: k["probability"], reverse=True
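To make the per-document bookkeeping concrete, here is a hand-traced sketch of the aggregation loop above (all pipeline scores invented):

# (best positive answer score, "no answer" score) per document, roughly as the
# transformers pipeline could return them with handle_impossible_answer=True.
docs_scores = [(0.9, 0.1),   # doc 1: strong positive answer
               (0.0, 0.8),   # doc 2: no positive answer extracted
               (0.4, 0.6)]   # doc 3: "no answer" slightly ahead

no_ans_gaps = []
best_overall_score = 0.0
for best_doc_score, no_ans_doc_score in docs_scores:
    if best_doc_score > best_overall_score:
        best_overall_score = best_doc_score
    no_ans_gaps.append(no_ans_doc_score - best_doc_score)

# no_ans_gaps -> approx. [-0.8, 0.8, 0.2]; best_overall_score -> 0.9.
# _calc_no_answer(no_ans_gaps, best_overall_score) then scores "no answer" as
# 0.9 - max(no_ans_gaps) = 0.1, with expit(0.1 / 8) ~ 0.50 as its
# pseudo-probability before the final sort.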
@@ -85,7 +85,7 @@ def no_answer_reader(request):
     if request.param == "transformers":
         return TransformersReader(model="deepset/roberta-base-squad2",
                                   tokenizer="deepset/roberta-base-squad2",
-                                  use_gpu=-1, n_best_per_passage=5)
+                                  use_gpu=-1, top_k_per_candidate=5)
 
 
 @pytest.fixture()