mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-30 16:47:19 +00:00
Merge pull request #24 from deepset-ai/no_answer
Add no_answer option to results
This commit is contained in:
commit
0862e8aa46
@ -27,7 +27,7 @@ class FARMReader:
|
||||
self,
|
||||
model_name_or_path,
|
||||
context_window_size=30,
|
||||
no_ans_threshold=-100,
|
||||
no_ans_boost=-100,
|
||||
batch_size=16,
|
||||
use_gpu=True,
|
||||
n_candidates_per_passage=2):
|
||||
@ -40,19 +40,24 @@ class FARMReader:
|
||||
....
|
||||
See https://huggingface.co/models for full list of available models.
|
||||
:param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
|
||||
:param no_ans_threshold: How much greater the no_answer logit needs to be over the pos_answer in order to be chosen.
|
||||
The higher the value, the more `uncertain` answers are accepted
|
||||
:param no_ans_boost: How much the no_answer logit is boosted/increased.
|
||||
The higher the value, the more likely a "no answer possible" is returned by the model
|
||||
:param batch_size: Number of samples the model receives in one batch for inference
|
||||
:param use_gpu: Whether to use GPU (if available)
|
||||
:param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
|
||||
Note: This is not the number of "final answers" you will receive
|
||||
(see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
|
||||
# TODO adjust farm. n_cand = 2 returns no answer + highest positive answer
|
||||
# should return no answer + 2 best positive answers
|
||||
# drawback: answers from a single paragraph might be very similar in text and score
|
||||
# we need to have more varied answers (by excluding overlapping answers?)
|
||||
"""
|
||||
|
||||
|
||||
self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
|
||||
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
|
||||
self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_threshold
|
||||
self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_boost # TODO adjust naming and concept in FARM
|
||||
self.no_ans_boost = no_ans_boost
|
||||
self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage
|
||||
|
||||
def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
|
||||
@ -183,16 +188,21 @@ class FARMReader:
|
||||
}
|
||||
input_dicts.append(cur)
|
||||
|
||||
# get answers from QA model (Top 5 per input paragraph)
|
||||
# get answers from QA model (Default: top 5 per input paragraph)
|
||||
predictions = self.inferencer.inference_from_dicts(
|
||||
dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
|
||||
)
|
||||
|
||||
# assemble answers from all the different paragraphs & format them
|
||||
# assemble answers from all the different paragraphs & format them.
|
||||
# For the "no answer" option, we collect all no_ans_gaps and decide how likely
|
||||
# a no answer is based on all no_ans_gaps values across all documents
|
||||
answers = []
|
||||
no_ans_gaps = []
|
||||
best_score_answer = 0
|
||||
for pred in predictions:
|
||||
for a in pred["predictions"][0]["answers"]:
|
||||
if a["answer"]: #skip "no answer"
|
||||
# skip "no answers" here
|
||||
if a["answer"]:
|
||||
cur = {"answer": a["answer"],
|
||||
"score": a["score"],
|
||||
"probability": float(expit(np.asarray([a["score"]]) / 8)), #just a pseudo prob for now
|
||||
@ -201,14 +211,45 @@ class FARMReader:
|
||||
"offset_end": a["offset_answer_end"] - a["offset_context_start"],
|
||||
"document_id": a["document_id"]}
|
||||
answers.append(cur)
|
||||
no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"])
|
||||
if a["score"] > best_score_answer:
|
||||
best_score_answer = a["score"]
|
||||
|
||||
# adjust no_ans_gaps
|
||||
no_ans_gaps = np.array(no_ans_gaps)
|
||||
no_ans_gaps_adjusted = no_ans_gaps + self.no_ans_boost
|
||||
|
||||
# We want to heuristically rank how likely or unlikely the "no answer" option is.
|
||||
|
||||
# case: all documents return no answer, then all no_ans_gaps are positive
|
||||
if np.sum(no_ans_gaps_adjusted < 0) == 0:
|
||||
# to rank we add the smallest no_ans_gap (a document where an answer would be nearly as likely as the no anser)
|
||||
# to the highest answer score we found
|
||||
no_ans_score = best_score_answer + min(no_ans_gaps_adjusted)
|
||||
# case: documents where answers are preferred over no answer, the no_ans_gap is negative
|
||||
else:
|
||||
# the lowest (highest negative) no_ans_gap would be needed as positive no_ans_boost for the
|
||||
# model to return "no answer" on all documents
|
||||
# we subtract this value from the best answer score to rank our "no answer" option
|
||||
# magically this is the same equation as used for the case above : )
|
||||
no_ans_score = best_score_answer + min(no_ans_gaps_adjusted)
|
||||
|
||||
cur = {"answer": "[computer says no answer is likely]",
|
||||
"score": no_ans_score,
|
||||
"probability": float(expit(np.asarray(no_ans_score) / 8)), # just a pseudo prob for now
|
||||
"context": "",
|
||||
"offset_start": 0,
|
||||
"offset_end": 0,
|
||||
"document_id": None}
|
||||
answers.append(cur)
|
||||
|
||||
# sort answers by their `probability` and select top-k
|
||||
answers = sorted(
|
||||
answers, key=lambda k: k["probability"], reverse=True
|
||||
)
|
||||
answers = answers[:top_k]
|
||||
|
||||
result = {"question": question,
|
||||
"answers": answers}
|
||||
"adjust_no_ans_boost": -min(no_ans_gaps_adjusted),
|
||||
"answers": answers}
|
||||
|
||||
return result
|
||||
return result
|
||||
@ -22,8 +22,10 @@ def print_answers(results, details="all"):
|
||||
for key in keys_to_drop:
|
||||
if key in a:
|
||||
del a[key]
|
||||
# print them
|
||||
pp.pprint(answers)
|
||||
|
||||
pp.pprint(answers)
|
||||
else:
|
||||
pp.pprint(results)
|
||||
|
||||
|
||||
def convert_labels_to_squad(labels_file):
|
||||
|
||||
@ -37,7 +37,8 @@ retriever = TfidfRetriever(document_store=document_store)
|
||||
# Reader use more powerful but slower deep learning models
|
||||
# You can select a local model or any of the QA models published on huggingface's model hub (https://huggingface.co/models)
|
||||
# here: a medium sized BERT QA model trained via FARM on Squad 2.0
|
||||
reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False)
|
||||
# You can adjust the model to return "no answer possible" with the no_ans_boost. Higher values mean the model prefers "no answer possible"
|
||||
reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False, no_ans_boost=0)
|
||||
|
||||
# OR: use alternatively a reader from huggingface's transformers package (https://github.com/huggingface/transformers)
|
||||
# reader = TransformersReader(model="distilbert-base-uncased-distilled-squad", tokenizer="distilbert-base-uncased", use_gpu=-1)
|
||||
@ -50,6 +51,7 @@ finder = Finder(reader, retriever)
|
||||
# The higher top_k_retriever, the better (but also the slower) your answers.
|
||||
prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
|
||||
|
||||
#prediction = finder.get_answers(question="Who is the daughter of Arya Stark?", top_k_reader=5) # impossible question test
|
||||
#prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
|
||||
#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user