Merge pull request #26 from deepset-ai/no_answer

Add no answer handling
Timo Moeller 2020-02-24 16:27:59 +01:00 committed by GitHub
commit c0910c82c5
3 changed files with 72 additions and 57 deletions


@@ -46,7 +46,7 @@ class Finder:
         # 3) Apply reader to get granular answer(s)
         logger.info(f"Applying the reader now to look for the answer in detail ...")
         results = self.reader.predict(question=question,
-                                      paragrahps=paragraphs,
+                                      paragraphs=paragraphs,
                                       meta_data_paragraphs=meta_data,
                                       top_k=top_k_reader)
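For orientation, a minimal sketch of how this corrected call is used (the paragraph texts and ids are illustrative, not from the diff):

# Hypothetical usage; assumes a FARMReader instance as configured in this commit.
paragraphs = ["Arya Stark is the daughter of Ned Stark.", "Dothraki is a constructed language."]
meta_data = [{"document_id": "11"}, {"document_id": "42"}]
results = reader.predict(question="Who is the father of Arya Stark?",
                         paragraphs=paragraphs,   # keyword previously misspelled as "paragrahps"
                         meta_data_paragraphs=meta_data,
                         top_k=5)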


@@ -27,10 +27,10 @@ class FARMReader:
         self,
         model_name_or_path,
         context_window_size=30,
-        no_ans_boost=-100,
-        batch_size=16,
+        batch_size=50,
         use_gpu=True,
-        n_candidates_per_passage=2):
+        no_ans_boost=None,
+        n_candidates_per_paragraph=1):
         """
         :param model_name_or_path: directory of a saved model or the name of a public model:
            - 'bert-base-cased'
@@ -40,25 +40,31 @@ class FARMReader:
            ....
            See https://huggingface.co/models for full list of available models.
         :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
-        :param no_ans_boost: How much the no_answer logit is boosted/increased.
-                             The higher the value, the more likely a "no answer possible" is returned by the model
         :param batch_size: Number of samples the model receives in one batch for inference
+                           Memory consumption is much lower in inference mode. Recommendation: increase the batch size to a value so only a single batch is used.
         :param use_gpu: Whether to use GPU (if available)
-        :param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
-                                         Note: This is not the number of "final answers" you will receive
-                                         (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
-        # TODO adjust farm. n_cand = 2 returns no answer + highest positive answer
-        # should return no answer + 2 best positive answers
-        # drawback: answers from a single paragraph might be very similar in text and score
-        # we need to have more varied answers (by excluding overlapping answers?)
+        :param no_ans_boost: How much the no_answer logit is boosted/increased.
+                             Possible values: None (default) = disable returning "no answer" predictions
+                                              Negative = lower chance of "no answer" being predicted
+                                              Positive = increase chance of "no answer"
+        :param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
+                                           Note: - This is not the number of "final answers" you will receive
+                                                   (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
+                                                 - FARM includes no_answer in the sorted list of predictions
         """
+        if no_ans_boost is None:
+            no_ans_boost = 0
+            self.return_no_answers = False
+        else:
+            self.return_no_answers = True
+        self.n_candidates_per_paragraph = n_candidates_per_paragraph
         self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
         self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
-        self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_boost  # TODO adjust naming and concept in FARM
-        self.no_ans_boost = no_ans_boost
-        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage
+        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
+        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1  # including possible no_answer

     def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
               use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
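To illustrate the new constructor defaults, a short usage sketch (the import path is assumed from the repo layout; the model name is the one used elsewhere in this PR):

from haystack.reader.farm import FARMReader  # assumed import path

# Default: no_ans_boost=None disables "no answer" predictions entirely.
reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2")

# Any explicit value (even 0) enables them; positive values make "no answer"
# more likely, negative values less likely.
reader_with_no_ans = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2",
                                no_ans_boost=0,
                                n_candidates_per_paragraph=2)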
@@ -146,7 +152,7 @@ class FARMReader:
         self.inferencer.model.save(directory)
         self.inferencer.processor.save(directory)

-    def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, max_processes=1):
+    def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None, max_processes=1):
         """
         Use loaded QA model to find answers for a question in the supplied paragraphs.
@@ -160,7 +166,7 @@ class FARMReader:
                        'offset_answer_end': 154,
                        'probability': 0.9787139466668613,
                        'score': None,
-                       'document_id': None
+                       'document_id': '1337'
                        },
                       ...
                      ]
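When "no answer" predictions are enabled, the answers list can additionally contain an entry of the following shape (field layout taken from _calc_no_answer below; the numbers are invented for illustration):

{'answer': None,        # None signals "no answer possible"
 'score': 8.0,          # best positive score minus the maximum no_ans_gap
 'probability': 0.73,   # pseudo probability, expit(score / 8)
 'context': None,
 'offset_start': 0,
 'offset_end': 0,
 'document_id': None}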
@@ -176,19 +182,19 @@ class FARMReader:
         """
         if meta_data_paragraphs is None:
-            meta_data_paragraphs = len(paragrahps) * [None]
-        assert len(paragrahps) == len(meta_data_paragraphs)
+            meta_data_paragraphs = len(paragraphs) * [None]
+        assert len(paragraphs) == len(meta_data_paragraphs)

         # convert input to FARM format
         input_dicts = []
-        for paragraph, meta_data in zip(paragrahps, meta_data_paragraphs):
+        for paragraph, meta_data in zip(paragraphs, meta_data_paragraphs):
             cur = {"text": paragraph,
                    "questions": [question],
                    "document_id": meta_data["document_id"]
                    }
             input_dicts.append(cur)

-        # get answers from QA model (Default: top 5 per input paragraph)
+        # get answers from QA model
         predictions = self.inferencer.inference_from_dicts(
             dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
         )
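For two retrieved paragraphs, the FARM input schema built above would look roughly like this (values illustrative):

input_dicts = [
    {"text": "Arya Stark is the daughter of Ned Stark.",
     "questions": ["Who is the father of Arya Stark?"],
     "document_id": "11"},
    {"text": "Dothraki is a constructed language.",
     "questions": ["Who is the father of Arya Stark?"],
     "document_id": "42"},
]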
@@ -200,6 +206,8 @@ class FARMReader:
         no_ans_gaps = []
         best_score_answer = 0
         for pred in predictions:
+            answers_per_paragraph = []
+            no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"])
             for a in pred["predictions"][0]["answers"]:
                 # skip "no answers" here
                 if a["answer"]:
@@ -210,38 +218,17 @@ class FARMReader:
                            "offset_start": a["offset_answer_start"] - a["offset_context_start"],
                            "offset_end": a["offset_answer_end"] - a["offset_context_start"],
                            "document_id": a["document_id"]}
-                    answers.append(cur)
-                    no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"])
+                    answers_per_paragraph.append(cur)
                     if a["score"] > best_score_answer:
                         best_score_answer = a["score"]
+            # only take n best candidates. Answers coming back from FARM are sorted with decreasing relevance.
+            answers += answers_per_paragraph[:self.n_candidates_per_paragraph]

-        # adjust no_ans_gaps
-        no_ans_gaps = np.array(no_ans_gaps)
-        no_ans_gaps_adjusted = no_ans_gaps + self.no_ans_boost
-
-        # We want to heuristically rank how likely or unlikely the "no answer" option is.
-        # case: all documents return no answer, then all no_ans_gaps are positive
-        if np.sum(no_ans_gaps_adjusted < 0) == 0:
-            # to rank we add the smallest no_ans_gap (a document where an answer would be nearly as likely as the no answer)
-            # to the highest answer score we found
-            no_ans_score = best_score_answer + min(no_ans_gaps_adjusted)
-        # case: documents where answers are preferred over no answer, the no_ans_gap is negative
-        else:
-            # the lowest (highest negative) no_ans_gap would be needed as positive no_ans_boost for the
-            # model to return "no answer" on all documents
-            # we subtract this value from the best answer score to rank our "no answer" option
-            # magically this is the same equation as used for the case above : )
-            no_ans_score = best_score_answer + min(no_ans_gaps_adjusted)
-
-        cur = {"answer": "[computer says no answer is likely]",
-               "score": no_ans_score,
-               "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
-               "context": "",
-               "offset_start": 0,
-               "offset_end": 0,
-               "document_id": None}
-        answers.append(cur)
+        # Calculate the score for predicting "no answer", relative to our best positive answer score
+        no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps, best_score_answer)
+        if self.return_no_answers:
+            answers.append(no_ans_prediction)

         # sort answers by their `probability` and select top-k
         answers = sorted(
@@ -249,7 +236,32 @@ class FARMReader:
         )
         answers = answers[:top_k]

         result = {"question": question,
-                  "adjust_no_ans_boost": -min(no_ans_gaps_adjusted),
+                  "no_ans_gap": max_no_ans_gap,
                   "answers": answers}
         return result

+    @staticmethod
+    def _calc_no_answer(no_ans_gaps, best_score_answer):
+        # "No answer" scores and positive answer scores are difficult to compare, because
+        # + a positive answer score is related to one specific document
+        # - a "no answer" score is related to all input documents
+        # Thus we compute the "no answer" score relative to the best possible answer and adjust it by
+        # the most significant difference between scores.
+        # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
+        # The no_ans_gaps coming from FARM indicate how much no_ans_boost would have to change to switch the prediction.
+        no_ans_gaps = np.array(no_ans_gaps)
+        max_no_ans_gap = np.max(no_ans_gaps)
+        if np.sum(no_ans_gaps < 0) == len(no_ans_gaps):  # all passages have "no answer" as their top score
+            no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative here, so this increases the best positive score
+        else:  # at least one passage predicts an answer (positive no_ans_gap)
+            no_ans_score = best_score_answer - max_no_ans_gap
+
+        no_ans_prediction = {"answer": None,
+                             "score": no_ans_score,
+                             "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
+                             "context": None,
+                             "offset_start": 0,
+                             "offset_end": 0,
+                             "document_id": None}
+        return no_ans_prediction, max_no_ans_gap
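To make the heuristic concrete, a small worked example with invented numbers:

import numpy as np
from scipy.special import expit  # same squashing used for the pseudo probability above

no_ans_gaps = [-1.5, 2.0, 0.5]   # one paragraph clearly prefers a positive answer (gap 2.0)
best_score_answer = 10.0
max_no_ans_gap = np.max(no_ans_gaps)               # 2.0
no_ans_score = best_score_answer - max_no_ans_gap  # 10.0 - 2.0 = 8.0, ranked below the best answer
print(float(expit(np.asarray(no_ans_score) / 8)))  # ~0.73 pseudo probability

If instead all gaps were negative, e.g. [-1.5, -0.3, -2.0], then max_no_ans_gap = -0.3 and no_ans_score = 10.0 + 0.3 = 10.3, so the "no answer" entry would outrank every positive answer.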


@@ -35,10 +35,10 @@ retriever = TfidfRetriever(document_store=document_store)
 # A reader scans the text chunks in detail and extracts the k best answers
 # Readers use more powerful but slower deep learning models
 # You can select a local model or any of the QA models published on huggingface's model hub (https://huggingface.co/models)
 # here: a medium sized BERT QA model trained via FARM on Squad 2.0
 # You can adjust the model to return "no answer possible" with the no_ans_boost. Higher values mean the model prefers "no answer possible"
-reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False, no_ans_boost=0)
+reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False)

 # OR: alternatively, use a reader from huggingface's transformers package (https://github.com/huggingface/transformers)
 # reader = TransformersReader(model="distilbert-base-uncased-distilled-squad", tokenizer="distilbert-base-uncased", use_gpu=-1)
@@ -51,7 +51,10 @@ finder = Finder(reader, retriever)
 # The higher top_k_retriever, the better (but also the slower) your answers.
 prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)

-#prediction = finder.get_answers(question="Who is the daughter of Arya Stark?", top_k_reader=5)  # impossible question test
+# to test impossible questions we need a large QA model, e.g. deepset/bert-large-uncased-whole-word-masking-squad2
+# and we need to enable returning "no answer possible" by setting no_ans_boost=X in FARMReader
+# prediction = finder.get_answers(question="Who is the first daughter of Arya Stark?", top_k_retriever=10, top_k_reader=5)

 #prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
 #prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)