Simplify no ans handling, disable no ans + sorting in private function

timoeller 2020-02-24 16:15:06 +01:00
parent 0f5b61d20a
commit f681026a56
2 changed files with 84 additions and 71 deletions

View File

@@ -23,42 +23,6 @@ class FARMReader:
- fine-tune the model on QA data via train()
"""
def __init__(
self,
model_name_or_path,
context_window_size=30,
no_ans_boost=-100,
batch_size=16,
use_gpu=True,
n_candidates_per_passage=2):
"""
:param model_name_or_path: directory of a saved model or the name of a public model:
- 'bert-base-cased'
- 'deepset/bert-base-cased-squad2'
                                   - 'deepset/bert-large-uncased-whole-word-masking-squad2'
- 'distilbert-base-uncased-distilled-squad'
....
See https://huggingface.co/models for full list of available models.
:param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
:param no_ans_boost: How much the no_answer logit is boosted/increased.
The higher the value, the more likely a "no answer possible" is returned by the model
:param batch_size: Number of samples the model receives in one batch for inference
:param use_gpu: Whether to use GPU (if available)
:param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
Note: - This is not the number of "final answers" you will receive
(see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
- FARM includes no_answer in the sorted list of predictions
"""
self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
        # TODO adjust terminology: a haystack passage is a FARM document (which gets divided into FARM passages depending on document length and max_seq_len)
self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage
def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
@@ -140,6 +104,49 @@ class FARMReader:
self.inferencer.model = trainer.train()
self.save(save_dir)
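    # A minimal sketch of invoking train() (the data directory and file names below
    # are hypothetical placeholders; the keyword arguments come from the signature above):
    #
    #   reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2")
    #   reader.train(data_dir="data/squad20",           # hypothetical path
    #                train_filename="train-v2.0.json",  # hypothetical file name
    #                dev_filename="dev-v2.0.json",      # hypothetical file name
    #                use_gpu=True,
    #                n_epochs=2,
    #                save_dir="my_finetuned_model")     # hypothetical save location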
def __init__(
self,
model_name_or_path,
context_window_size=30,
batch_size=50,
use_gpu=True,
no_ans_boost=None,
n_candidates_per_paragraph=1):
"""
:param model_name_or_path: directory of a saved model or the name of a public model:
- 'bert-base-cased'
- 'deepset/bert-base-cased-squad2'
                                   - 'deepset/bert-large-uncased-whole-word-masking-squad2'
- 'distilbert-base-uncased-distilled-squad'
....
See https://huggingface.co/models for full list of available models.
:param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
:param batch_size: Number of samples the model receives in one batch for inference
                           Memory consumption is much lower in inference mode, so you can usually afford a larger batch size;
                           ideally, choose a value large enough that all samples fit into a single batch.
:param use_gpu: Whether to use GPU (if available)
:param no_ans_boost: How much the no_answer logit is boosted/increased.
Possible values: None (default) = disable returning "no answer" predictions
Negative = lower chance of "no answer" being predicted
Positive = increase chance of "no answer"
:param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
Note: - This is not the number of "final answers" you will receive
(see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
- FARM includes no_answer in the sorted list of predictions
"""
if no_ans_boost is None:
no_ans_boost = 0
self.return_no_answers = False
else:
self.return_no_answers = True
self.n_candidates_per_paragraph = n_candidates_per_paragraph
self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1 # including possible no_answer
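    # Usage sketch for the new constructor semantics (model name as in the tutorial
    # below; the boost value is illustrative):
    #
    #   # Default: no_ans_boost=None disables "no answer" predictions entirely
    #   reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2")
    #
    #   # Any explicit boost (even 0) enables them; positive values make "no answer"
    #   # more likely, negative values less likely
    #   reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2",
    #                       no_ans_boost=0, n_candidates_per_paragraph=3)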
def save(self, directory):
logger.info(f"Saving reader model to {directory}")
self.inferencer.model.save(directory)
@@ -159,7 +166,7 @@ class FARMReader:
'offset_answer_end': 154,
'probability': 0.9787139466668613,
'score': None,
'document_id': None
'document_id': '1337'
},
...
]
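    # To make the returned shape concrete, a sketch of consuming such a result
    # (field names follow the answer dicts built below; values are illustrative):
    #
    #   # given a `result` dict as returned by FARMReader.predict():
    #   for answer in result["answers"]:
    #       print(answer["answer"], answer["probability"], answer["document_id"])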
@@ -199,12 +206,11 @@ class FARMReader:
no_ans_gaps = []
best_score_answer = 0
for pred in predictions:
positive_found = False
answers_per_paragraph = []
no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"])
for a in pred["predictions"][0]["answers"]:
# skip "no answers" here
# For now we only take one prediction from each passage
# TODO use more predictions per passage when setting n_candidates_per_passage + make FARM predictions more varied
if(not positive_found and a["answer"]):
if a["answer"]:
cur = {"answer": a["answer"],
"score": a["score"],
"probability": float(expit(np.asarray([a["score"]]) / 8)), #just a pseudo prob for now
@@ -212,35 +218,17 @@ class FARMReader:
"offset_start": a["offset_answer_start"] - a["offset_context_start"],
"offset_end": a["offset_answer_end"] - a["offset_context_start"],
"document_id": a["document_id"]}
answers.append(cur)
no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"])
answers_per_paragraph.append(cur)
if a["score"] > best_score_answer:
best_score_answer = a["score"]
positive_found = True
answers += answers_per_paragraph[:self.n_candidates_per_paragraph]
# "no answer" scores and positive answers scores are difficult to compare, because
# + a positive answer score is related to one specific document
# - a "no answer" score is related to all input documents
# Thus we compute the "no answer" score relative to the best possible answer and adjust it by
# the most significant difference between scores.
# Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
        # The no_ans_gap coming from FARM means how much no_ans_boost would need to change to switch the prediction
no_ans_gaps = np.array(no_ans_gaps)
max_no_ans_gap = np.max(no_ans_gaps)
if(np.sum(no_ans_gaps < 0) == len(no_ans_gaps)): # all passages "no answer" as top score
no_ans_score = best_score_answer - max_no_ans_gap # max_no_ans_gap is negative, so it increases best pos score
else: # case: at least one passage predicts an answer (positive no_ans_gap)
no_ans_score = best_score_answer - max_no_ans_gap
cur = {"answer": None,
"score": no_ans_score,
"probability": float(expit(np.asarray(no_ans_score) / 8)), # just a pseudo prob for now
"context": None,
"offset_start": 0,
"offset_end": 0,
"document_id": None}
answers.append(cur)
# Calculate the score for predicting "no answer", relative to our best positive answer score
        no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps, best_score_answer)
if self.return_no_answers:
answers.append(no_ans_prediction)
# sort answers by their `probability` and select top-k
answers = sorted(
@@ -251,4 +239,29 @@ class FARMReader:
"no_ans_gap": max_no_ans_gap,
"answers": answers}
return result
@staticmethod
    def _calc_no_answer(no_ans_gaps, best_score_answer):
# "no answer" scores and positive answers scores are difficult to compare, because
# + a positive answer score is related to one specific document
# - a "no answer" score is related to all input documents
# Thus we compute the "no answer" score relative to the best possible answer and adjust it by
# the most significant difference between scores.
# Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
        # The no_ans_gap coming from FARM means how much no_ans_boost would need to change to switch the prediction
no_ans_gaps = np.array(no_ans_gaps)
max_no_ans_gap = np.max(no_ans_gaps)
if (np.sum(no_ans_gaps < 0) == len(no_ans_gaps)): # all passages "no answer" as top score
no_ans_score = best_score_answer - max_no_ans_gap # max_no_ans_gap is negative, so it increases best pos score
else: # case: at least one passage predicts an answer (positive no_ans_gap)
no_ans_score = best_score_answer - max_no_ans_gap
no_ans_prediction = {"answer": None,
"score": no_ans_score,
"probability": float(expit(np.asarray(no_ans_score) / 8)), # just a pseudo prob for now
"context": None,
"offset_start": 0,
"offset_end": 0,
"document_id": None}
return no_ans_prediction, max_no_ans_gap
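    # A tiny numeric walk-through (made-up scores): here every passage prefers
    # "no answer" (all gaps negative), so the "no answer" score ends up above the
    # best positive answer score:
    #
    #   no_ans_prediction, max_no_ans_gap = FARMReader._calc_no_answer(
    #       no_ans_gaps=[-2.0, -0.5], best_score_answer=11.0)
    #   # max_no_ans_gap == -0.5
    #   # no_ans_prediction["score"] == 11.0 - (-0.5) == 11.5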

View File

@@ -35,10 +35,10 @@ retriever = TfidfRetriever(document_store=document_store)
# A reader scans the text chunks in detail and extracts the k best answers
# Readers use more powerful but slower deep learning models
# You can select a local model or any of the QA models published on huggingface's model hub (https://huggingface.co/models)
# here: a medium sized BERT QA model trained via FARM on Squad 2.0
# You can adjust the model to return "no answer possible" with the no_ans_boost. Higher values mean the model prefers "no answer possible"
reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False, no_ans_boost=0)
reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False)
# OR: use alternatively a reader from huggingface's transformers package (https://github.com/huggingface/transformers)
# reader = TransformersReader(model="distilbert-base-uncased-distilled-squad", tokenizer="distilbert-base-uncased", use_gpu=-1)
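# To re-enable "no answer possible" under the new semantics, pass an explicit boost
# (a sketch; 0 keeps the model's raw preference, positive values favour "no answer"):
# reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2",
#                     use_gpu=False, no_ans_boost=0)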
@@ -52,8 +52,8 @@ finder = Finder(reader, retriever)
prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
# to test impossible questions we need a large QA model, e.g. deepset/bert-large-uncased-whole-word-masking-squad2
#prediction = finder.get_answers(question="Who is the first daughter of Arya Stark?", top_k_retriever=10, top_k_reader=5)
# and we need to enable returning "no answer possible" by setting no_ans_boost=X in FARMReader
# prediction = finder.get_answers(question="Who is the first daughter of Arya Stark?", top_k_retriever=10, top_k_reader=5)
#prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)