mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-07 15:23:31 +00:00
Merge pull request #26 from deepset-ai/no_answer
Add no answer handling
This commit is contained in:
commit
c0910c82c5
@ -46,7 +46,7 @@ class Finder:
|
|||||||
# 3) Apply reader to get granular answer(s)
|
# 3) Apply reader to get granular answer(s)
|
||||||
logger.info(f"Applying the reader now to look for the answer in detail ...")
|
logger.info(f"Applying the reader now to look for the answer in detail ...")
|
||||||
results = self.reader.predict(question=question,
|
results = self.reader.predict(question=question,
|
||||||
paragrahps=paragraphs,
|
paragraphs=paragraphs,
|
||||||
meta_data_paragraphs=meta_data,
|
meta_data_paragraphs=meta_data,
|
||||||
top_k=top_k_reader)
|
top_k=top_k_reader)
|
||||||
|
|
||||||
|
@ -27,10 +27,10 @@ class FARMReader:
|
|||||||
self,
|
self,
|
||||||
model_name_or_path,
|
model_name_or_path,
|
||||||
context_window_size=30,
|
context_window_size=30,
|
||||||
no_ans_boost=-100,
|
batch_size=50,
|
||||||
batch_size=16,
|
|
||||||
use_gpu=True,
|
use_gpu=True,
|
||||||
n_candidates_per_passage=2):
|
no_ans_boost=None,
|
||||||
|
n_candidates_per_paragraph=1):
|
||||||
"""
|
"""
|
||||||
:param model_name_or_path: directory of a saved model or the name of a public model:
|
:param model_name_or_path: directory of a saved model or the name of a public model:
|
||||||
- 'bert-base-cased'
|
- 'bert-base-cased'
|
||||||
@ -40,25 +40,31 @@ class FARMReader:
|
|||||||
....
|
....
|
||||||
See https://huggingface.co/models for full list of available models.
|
See https://huggingface.co/models for full list of available models.
|
||||||
:param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
|
:param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
|
||||||
:param no_ans_boost: How much the no_answer logit is boosted/increased.
|
|
||||||
The higher the value, the more likely a "no answer possible" is returned by the model
|
|
||||||
:param batch_size: Number of samples the model receives in one batch for inference
|
:param batch_size: Number of samples the model receives in one batch for inference
|
||||||
|
Memory consumption is much lower in inference mode. Recommendation: increase the batch size to a value so only a single batch is used.
|
||||||
:param use_gpu: Whether to use GPU (if available)
|
:param use_gpu: Whether to use GPU (if available)
|
||||||
:param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
|
:param no_ans_boost: How much the no_answer logit is boosted/increased.
|
||||||
Note: This is not the number of "final answers" you will receive
|
Possible values: None (default) = disable returning "no answer" predictions
|
||||||
|
Negative = lower chance of "no answer" being predicted
|
||||||
|
Positive = increase chance of "no answer"
|
||||||
|
:param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
|
||||||
|
Note: - This is not the number of "final answers" you will receive
|
||||||
(see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
|
(see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
|
||||||
# TODO adjust farm. n_cand = 2 returns no answer + highest positive answer
|
- FARM includes no_answer in the sorted list of predictions
|
||||||
# should return no answer + 2 best positive answers
|
|
||||||
# drawback: answers from a single paragraph might be very similar in text and score
|
|
||||||
# we need to have more varied answers (by excluding overlapping answers?)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if no_ans_boost is None:
|
||||||
|
no_ans_boost = 0
|
||||||
|
self.return_no_answers = False
|
||||||
|
else:
|
||||||
|
self.return_no_answers = True
|
||||||
|
self.n_candidates_per_paragraph = n_candidates_per_paragraph
|
||||||
self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
|
self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
|
||||||
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
|
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
|
||||||
self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_boost # TODO adjust naming and concept in FARM
|
self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
|
||||||
self.no_ans_boost = no_ans_boost
|
self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1 # including possible no_answer
|
||||||
self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage
|
|
||||||
|
|
||||||
def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
|
def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
|
||||||
use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
|
use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
|
||||||
@ -146,7 +152,7 @@ class FARMReader:
|
|||||||
self.inferencer.model.save(directory)
|
self.inferencer.model.save(directory)
|
||||||
self.inferencer.processor.save(directory)
|
self.inferencer.processor.save(directory)
|
||||||
|
|
||||||
def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, max_processes=1):
|
def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None, max_processes=1):
|
||||||
"""
|
"""
|
||||||
Use loaded QA model to find answers for a question in the supplied paragraphs.
|
Use loaded QA model to find answers for a question in the supplied paragraphs.
|
||||||
|
|
||||||
@ -160,7 +166,7 @@ class FARMReader:
|
|||||||
'offset_answer_end': 154,
|
'offset_answer_end': 154,
|
||||||
'probability': 0.9787139466668613,
|
'probability': 0.9787139466668613,
|
||||||
'score': None,
|
'score': None,
|
||||||
'document_id': None
|
'document_id': '1337'
|
||||||
},
|
},
|
||||||
...
|
...
|
||||||
]
|
]
|
||||||
@ -176,19 +182,19 @@ class FARMReader:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if meta_data_paragraphs is None:
|
if meta_data_paragraphs is None:
|
||||||
meta_data_paragraphs = len(paragrahps) * [None]
|
meta_data_paragraphs = len(paragraphs) * [None]
|
||||||
assert len(paragrahps) == len(meta_data_paragraphs)
|
assert len(paragraphs) == len(meta_data_paragraphs)
|
||||||
|
|
||||||
# convert input to FARM format
|
# convert input to FARM format
|
||||||
input_dicts = []
|
input_dicts = []
|
||||||
for paragraph, meta_data in zip(paragrahps, meta_data_paragraphs):
|
for paragraph, meta_data in zip(paragraphs, meta_data_paragraphs):
|
||||||
cur = {"text": paragraph,
|
cur = {"text": paragraph,
|
||||||
"questions": [question],
|
"questions": [question],
|
||||||
"document_id": meta_data["document_id"]
|
"document_id": meta_data["document_id"]
|
||||||
}
|
}
|
||||||
input_dicts.append(cur)
|
input_dicts.append(cur)
|
||||||
|
|
||||||
# get answers from QA model (Default: top 5 per input paragraph)
|
# get answers from QA model
|
||||||
predictions = self.inferencer.inference_from_dicts(
|
predictions = self.inferencer.inference_from_dicts(
|
||||||
dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
|
dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
|
||||||
)
|
)
|
||||||
@ -200,6 +206,8 @@ class FARMReader:
|
|||||||
no_ans_gaps = []
|
no_ans_gaps = []
|
||||||
best_score_answer = 0
|
best_score_answer = 0
|
||||||
for pred in predictions:
|
for pred in predictions:
|
||||||
|
answers_per_paragraph = []
|
||||||
|
no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"])
|
||||||
for a in pred["predictions"][0]["answers"]:
|
for a in pred["predictions"][0]["answers"]:
|
||||||
# skip "no answers" here
|
# skip "no answers" here
|
||||||
if a["answer"]:
|
if a["answer"]:
|
||||||
@ -210,38 +218,17 @@ class FARMReader:
|
|||||||
"offset_start": a["offset_answer_start"] - a["offset_context_start"],
|
"offset_start": a["offset_answer_start"] - a["offset_context_start"],
|
||||||
"offset_end": a["offset_answer_end"] - a["offset_context_start"],
|
"offset_end": a["offset_answer_end"] - a["offset_context_start"],
|
||||||
"document_id": a["document_id"]}
|
"document_id": a["document_id"]}
|
||||||
answers.append(cur)
|
answers_per_paragraph.append(cur)
|
||||||
no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"])
|
|
||||||
if a["score"] > best_score_answer:
|
if a["score"] > best_score_answer:
|
||||||
best_score_answer = a["score"]
|
best_score_answer = a["score"]
|
||||||
|
# only take n best candidates. Answers coming back from FARM are sorted with decreasing relevance.
|
||||||
|
answers += answers_per_paragraph[:self.n_candidates_per_paragraph]
|
||||||
|
|
||||||
# adjust no_ans_gaps
|
# Calculate the score for predicting "no answer", relative to our best positive answer score
|
||||||
no_ans_gaps = np.array(no_ans_gaps)
|
no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps,best_score_answer)
|
||||||
no_ans_gaps_adjusted = no_ans_gaps + self.no_ans_boost
|
if self.return_no_answers:
|
||||||
|
answers.append(no_ans_prediction)
|
||||||
# We want to heuristically rank how likely or unlikely the "no answer" option is.
|
|
||||||
|
|
||||||
# case: all documents return no answer, then all no_ans_gaps are positive
|
|
||||||
if np.sum(no_ans_gaps_adjusted < 0) == 0:
|
|
||||||
# to rank we add the smallest no_ans_gap (a document where an answer would be nearly as likely as the no anser)
|
|
||||||
# to the highest answer score we found
|
|
||||||
no_ans_score = best_score_answer + min(no_ans_gaps_adjusted)
|
|
||||||
# case: documents where answers are preferred over no answer, the no_ans_gap is negative
|
|
||||||
else:
|
|
||||||
# the lowest (highest negative) no_ans_gap would be needed as positive no_ans_boost for the
|
|
||||||
# model to return "no answer" on all documents
|
|
||||||
# we subtract this value from the best answer score to rank our "no answer" option
|
|
||||||
# magically this is the same equation as used for the case above : )
|
|
||||||
no_ans_score = best_score_answer + min(no_ans_gaps_adjusted)
|
|
||||||
|
|
||||||
cur = {"answer": "[computer says no answer is likely]",
|
|
||||||
"score": no_ans_score,
|
|
||||||
"probability": float(expit(np.asarray(no_ans_score) / 8)), # just a pseudo prob for now
|
|
||||||
"context": "",
|
|
||||||
"offset_start": 0,
|
|
||||||
"offset_end": 0,
|
|
||||||
"document_id": None}
|
|
||||||
answers.append(cur)
|
|
||||||
|
|
||||||
# sort answers by their `probability` and select top-k
|
# sort answers by their `probability` and select top-k
|
||||||
answers = sorted(
|
answers = sorted(
|
||||||
@ -249,7 +236,32 @@ class FARMReader:
|
|||||||
)
|
)
|
||||||
answers = answers[:top_k]
|
answers = answers[:top_k]
|
||||||
result = {"question": question,
|
result = {"question": question,
|
||||||
"adjust_no_ans_boost": -min(no_ans_gaps_adjusted),
|
"no_ans_gap": max_no_ans_gap,
|
||||||
"answers": answers}
|
"answers": answers}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
def _calc_no_answer(no_ans_gaps, best_score_answer):
    """
    Build the single "no answer" prediction for a query and report the largest no_ans_gap.

    "No answer" scores and positive answer scores are difficult to compare, because
      + a positive answer score is related to one specific document
      - a "no answer" score is related to all input documents
    Thus the "no answer" score is computed relative to the best positive answer score and
    shifted by the most significant gap:

        no_ans_score = best_score_answer - max(no_ans_gaps)

    The same formula applies whether all gaps are negative or at least one is positive,
    so no case distinction is needed.

    :param no_ans_gaps: Non-empty sequence of numbers, one no_ans_gap per passage as reported by FARM.
                        NOTE(review): the original inline comments describe a gap as "how much
                        no_ans_boost should change to switch predictions" and treat all-negative gaps
                        as "every passage ranks 'no answer' on top" - confirm the sign convention
                        against the FARM version in use.
    :param best_score_answer: Highest score among the positive answers collected for this query.
    :return: Tuple `(no_ans_prediction, max_no_ans_gap)`. `no_ans_prediction` is a dict with the same
             keys as a regular answer (`answer` and `context` are None, offsets are 0,
             `document_id` is None); `max_no_ans_gap` is the maximum of `no_ans_gaps`.
    :raises ValueError: If `no_ans_gaps` is empty.
    """
    no_ans_gaps = np.asarray(no_ans_gaps)
    if no_ans_gaps.size == 0:
        # np.max would raise an opaque "zero-size array" ValueError; keep the type, improve the message
        raise ValueError("no_ans_gaps must contain at least one value to compute a 'no answer' score")

    max_no_ans_gap = np.max(no_ans_gaps)
    # A negative max gap raises the score above the best positive answer, a positive one lowers it
    no_ans_score = best_score_answer - max_no_ans_gap

    no_ans_prediction = {"answer": None,
                         "score": no_ans_score,
                         "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
                         "context": None,
                         "offset_start": 0,
                         "offset_end": 0,
                         "document_id": None}
    return no_ans_prediction, max_no_ans_gap
|
@ -38,7 +38,7 @@ retriever = TfidfRetriever(document_store=document_store)
|
|||||||
# You can select a local model or any of the QA models published on huggingface's model hub (https://huggingface.co/models)
|
# You can select a local model or any of the QA models published on huggingface's model hub (https://huggingface.co/models)
|
||||||
# here: a medium sized BERT QA model trained via FARM on Squad 2.0
|
# here: a medium sized BERT QA model trained via FARM on Squad 2.0
|
||||||
# You can adjust the model to return "no answer possible" with the no_ans_boost. Higher values mean the model prefers "no answer possible"
|
# You can adjust the model to return "no answer possible" with the no_ans_boost. Higher values mean the model prefers "no answer possible"
|
||||||
reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False, no_ans_boost=0)
|
reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False)
|
||||||
|
|
||||||
# OR: use alternatively a reader from huggingface's transformers package (https://github.com/huggingface/transformers)
|
# OR: use alternatively a reader from huggingface's transformers package (https://github.com/huggingface/transformers)
|
||||||
# reader = TransformersReader(model="distilbert-base-uncased-distilled-squad", tokenizer="distilbert-base-uncased", use_gpu=-1)
|
# reader = TransformersReader(model="distilbert-base-uncased-distilled-squad", tokenizer="distilbert-base-uncased", use_gpu=-1)
|
||||||
@ -51,7 +51,10 @@ finder = Finder(reader, retriever)
|
|||||||
# The higher top_k_retriever, the better (but also the slower) your answers.
|
# The higher top_k_retriever, the better (but also the slower) your answers.
|
||||||
prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
|
prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
|
||||||
|
|
||||||
#prediction = finder.get_answers(question="Who is the daughter of Arya Stark?", top_k_reader=5) # impossible question test
|
# to test impossible questions we need a large QA model, e.g. deepset/bert-large-uncased-whole-word-masking-squad2
|
||||||
|
# and we need to enable returning "no answer possible" by setting no_ans_boost=X in FARMReader
|
||||||
|
# prediction = finder.get_answers(question="Who is the first daughter of Arya Stark?", top_k_retriever=10, top_k_reader=5)
|
||||||
|
|
||||||
#prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
|
#prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
|
||||||
#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
|
#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user