From ef9b99c3cce233cb676130a2abec1d587bbe3032 Mon Sep 17 00:00:00 2001
From: timoeller
Date: Fri, 21 Feb 2020 18:27:53 +0100
Subject: [PATCH 1/5] Add no answer handling and sort no answer into positive predictions

---
 haystack/reader/farm.py                  | 57 +++++++++++-------------
 tutorials/Tutorial1_Basic_QA_Pipeline.py |  5 ++-
 2 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index df534096f..43c31173c 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -45,19 +45,18 @@ class FARMReader:
         :param batch_size: Number of samples the model receives in one batch for inference
         :param use_gpu: Whether to use GPU (if available)
         :param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
-                                         Note: This is not the number of "final answers" you will receive
-                                         (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
-                                         # TODO adjust farm. n_cand = 2 returns no answer + highest positive answer
-                                         # should return no answer + 2 best positive answers
-                                         # drawback: answers from a single paragraph might be very similar in text and score
-                                         # we need to have more varied answers (by excluding overlapping answers?)
+                                         Note: - This is not the number of "final answers" you will receive
+                                                 (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
+                                               - FARM includes no_answer in the sorted list of predictions
+
+
         """
 
         self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
         self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
-        self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_boost # TODO adjust naming and concept in FARM
-        self.no_ans_boost = no_ans_boost
+        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
+        # TODO adjust terminology: a haystack passage is a FARM document (which gets divided into FARM passages depending on document length and max_seq_len)
         self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage
 
     def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
               use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
               max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
@@ -188,7 +187,7 @@ class FARMReader:
                    }
             input_dicts.append(cur)
 
-        # get answers from QA model (Default: top 5 per input paragraph)
+        # get answers from QA model
         predictions = self.inferencer.inference_from_dicts(
             dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
         )
@@ -200,9 +199,10 @@ class FARMReader:
         no_ans_gaps = []
         best_score_answer = 0
         for pred in predictions:
+            positive_found = False
             for a in pred["predictions"][0]["answers"]:
                 # skip "no answers" here
-                if a["answer"]:
+                if(not positive_found and a["answer"]):
                     cur = {"answer": a["answer"],
                            "score": a["score"],
                            "probability": float(expit(np.asarray([a["score"]]) / 8)), #just a pseudo prob for now
@@ -214,30 +214,27 @@ class FARMReader:
                     no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"])
                     if a["score"] > best_score_answer:
                         best_score_answer = a["score"]
+                    positive_found = True
-
-        # adjust no_ans_gaps
+
+        # "no answer" scores and positive answer scores are difficult to compare, because
+        # + a positive answer score is related to one specific document
+        # - a "no answer" score is related to all input documents
+        # Thus we compute the "no answer" score relative to the best possible answer and adjust it by
+        # the most significant difference between scores.
+        # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
+        # No_ans_gap coming from FARM means how much no_ans_boost should change to switch predictions
         no_ans_gaps = np.array(no_ans_gaps)
-        no_ans_gaps_adjusted = no_ans_gaps + self.no_ans_boost
+        max_no_ans_gap = np.max(no_ans_gaps)
+        if(np.sum(no_ans_gaps < 0) == len(no_ans_gaps)):  # all passages "no answer" as top score
+            no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative, so it increases best pos score
+        else:  # case: at least one passage predicts an answer (positive no_ans_gap)
+            no_ans_score = best_score_answer - max_no_ans_gap
 
-        # We want to heuristically rank how likely or unlikely the "no answer" option is.
-
-        # case: all documents return no answer, then all no_ans_gaps are positive
-        if np.sum(no_ans_gaps_adjusted < 0) == 0:
-            # to rank we add the smallest no_ans_gap (a document where an answer would be nearly as likely as the no anser)
-            # to the highest answer score we found
-            no_ans_score = best_score_answer + min(no_ans_gaps_adjusted)
-        # case: documents where answers are preferred over no answer, the no_ans_gap is negative
-        else:
-            # the lowest (highest negative) no_ans_gap would be needed as positive no_ans_boost for the
-            # model to return "no answer" on all documents
-            # we subtract this value from the best answer score to rank our "no answer" option
-            # magically this is the same equation as used for the case above : )
-            no_ans_score = best_score_answer + min(no_ans_gaps_adjusted)
-
-        cur = {"answer": "[computer says no answer is likely]",
+        cur = {"answer": None,
                "score": no_ans_score,
                "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
-               "context": "",
+               "context": None,
                "offset_start": 0,
                "offset_end": 0,
                "document_id": None}
         answers.append(cur)
 
         # sort answers by their `probability` and select top-k
         answers = sorted(
@@ -249,7 +246,7 @@ class FARMReader:
         )
         answers = answers[:top_k]
         result = {"question": question,
-                  "adjust_no_ans_boost": -min(no_ans_gaps_adjusted),
+                  "no_ans_gap": max_no_ans_gap,
                   "answers": answers}
 
         return result
\ No newline at end of file
diff --git a/tutorials/Tutorial1_Basic_QA_Pipeline.py b/tutorials/Tutorial1_Basic_QA_Pipeline.py
index 1b117e7c7..6efdfc900 100755
--- a/tutorials/Tutorial1_Basic_QA_Pipeline.py
+++ b/tutorials/Tutorial1_Basic_QA_Pipeline.py
@@ -51,7 +51,10 @@ finder = Finder(reader, retriever)
 # The higher top_k_retriever, the better (but also the slower) your answers.
 prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
 
-#prediction = finder.get_answers(question="Who is the daughter of Arya Stark?", top_k_reader=5) # impossible question test
+# to test impossible questions we need a large QA model, e.g. deepset/bert-large-uncased-whole-word-masking-squad2
+#prediction = finder.get_answers(question="Who is the first daughter of Arya Stark?", top_k_retriever=10, top_k_reader=5)
+
+
 #prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
 #prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
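The ranking arithmetic this patch introduces is easiest to check with toy numbers. A minimal standalone sketch (the gaps and scores below are made up, not real FARM output):

import numpy as np
from scipy.special import expit

best_score_answer = 10.0

# One no_ans_gap per input document; a positive gap means that document's
# top span beat the "no answer" option by that margin.
no_ans_gaps = np.array([-1.5, -3.2, 2.0])   # one document prefers an answer
max_no_ans_gap = np.max(no_ans_gaps)        # 2.0
no_ans_score = best_score_answer - max_no_ans_gap
print(no_ans_score)                         # 8.0 -> ranked below the best answer

# If every document prefers "no answer", all gaps are negative and the same
# formula lifts the "no answer" option above the best positive answer:
no_ans_gaps = np.array([-1.5, -3.2])
no_ans_score = best_score_answer - np.max(no_ans_gaps)
print(no_ans_score)                         # 11.5 -> ranked above the best answer

# the same pseudo probability as used in predict()
print(float(expit(np.asarray(no_ans_score) / 8)))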
From 0f5b61d20aeff527f4e03d5b38353b2a753b11cb Mon Sep 17 00:00:00 2001
From: timoeller
Date: Mon, 24 Feb 2020 12:28:49 +0100
Subject: [PATCH 2/5] Fix typo

---
 haystack/__init__.py    |  2 +-
 haystack/reader/farm.py | 10 ++++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/haystack/__init__.py b/haystack/__init__.py
index 916fe5259..42dc06718 100644
--- a/haystack/__init__.py
+++ b/haystack/__init__.py
@@ -46,7 +46,7 @@ class Finder:
         # 3) Apply reader to get granular answer(s)
         logger.info(f"Applying the reader now to look for the answer in detail ...")
         results = self.reader.predict(question=question,
-                                      paragrahps=paragraphs,
+                                      paragraphs=paragraphs,
                                       meta_data_paragraphs=meta_data,
                                       top_k=top_k_reader)
 
diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index 43c31173c..cf0325742 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -145,7 +145,7 @@ class FARMReader:
         self.inferencer.model.save(directory)
         self.inferencer.processor.save(directory)
 
-    def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, max_processes=1):
+    def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None, max_processes=1):
         """
         Use loaded QA model to find answers for a question in the supplied paragraphs.
 
@@ -175,12 +175,12 @@
         """
 
         if meta_data_paragraphs is None:
-            meta_data_paragraphs = len(paragrahps) * [None]
-        assert len(paragrahps) == len(meta_data_paragraphs)
+            meta_data_paragraphs = len(paragraphs) * [None]
+        assert len(paragraphs) == len(meta_data_paragraphs)
 
         # convert input to FARM format
         input_dicts = []
-        for paragraph, meta_data in zip(paragrahps, meta_data_paragraphs):
+        for paragraph, meta_data in zip(paragraphs, meta_data_paragraphs):
             cur = {"text": paragraph,
                    "questions": [question],
                    "document_id": meta_data["document_id"]
@@ -202,6 +202,8 @@ class FARMReader:
             positive_found = False
             for a in pred["predictions"][0]["answers"]:
                 # skip "no answers" here
+                # For now we only take one prediction from each passage
+                # TODO use more predictions per passage when setting n_candidates_per_passage + make FARM predictions more varied
                 if(not positive_found and a["answer"]):
                     cur = {"answer": a["answer"],
                            "score": a["score"],
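The predict() signature fixed here takes plain paragraph strings plus optional per-paragraph metadata. A sketch of the conversion into FARM's input format (data made up; this mirrors the loop added in patch 1):

question = "Who is the father of Arya Stark?"
paragraphs = ["Eddard Stark was the head of House Stark.",
              "Arya Stark is a daughter of Eddard and Catelyn Stark."]
meta_data_paragraphs = [{"document_id": "11"}, {"document_id": "42"}]

# Each paragraph becomes one dict that FARM's Inferencer can consume.
input_dicts = []
for paragraph, meta_data in zip(paragraphs, meta_data_paragraphs):
    input_dicts.append({"text": paragraph,
                        "questions": [question],
                        "document_id": meta_data["document_id"]})

# The dicts are then passed on, roughly as:
# predictions = self.inferencer.inference_from_dicts(dicts=input_dicts, rest_api_schema=True, max_processes=1)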
From f681026a5612b879096306c6bcf0d80dc5f2c848 Mon Sep 17 00:00:00 2001
From: timoeller
Date: Mon, 24 Feb 2020 16:15:06 +0100
Subject: [PATCH 3/5] Simplify no ans handling, disable no ans + sorting in private function

---
 haystack/reader/farm.py                  | 147 ++++++++++++-----------
 tutorials/Tutorial1_Basic_QA_Pipeline.py |   8 +-
 2 files changed, 84 insertions(+), 71 deletions(-)

diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index cf0325742..2877c5a28 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -23,42 +23,6 @@ class FARMReader:
     - fine-tune the model on QA data via train()
     """
 
-    def __init__(
-            self,
-            model_name_or_path,
-            context_window_size=30,
-            no_ans_boost=-100,
-            batch_size=16,
-            use_gpu=True,
-            n_candidates_per_passage=2):
-        """
-        :param model_name_or_path: directory of a saved model or the name of a public model:
-                                   - 'bert-base-cased'
-                                   - 'deepset/bert-base-cased-squad2'
-                                   - 'deepset/bert-base-cased-squad2'
-                                   - 'distilbert-base-uncased-distilled-squad'
-                                   ....
-                                   See https://huggingface.co/models for full list of available models.
-        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
-        :param no_ans_boost: How much the no_answer logit is boosted/increased.
-                             The higher the value, the more likely a "no answer possible" is returned by the model
-        :param batch_size: Number of samples the model receives in one batch for inference
-        :param use_gpu: Whether to use GPU (if available)
-        :param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
-                                         Note: - This is not the number of "final answers" you will receive
-                                                 (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
-                                               - FARM includes no_answer in the sorted list of predictions
-
-
-        """
-
-
-        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
-        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
-        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
-        # TODO adjust terminology: a haystack passage is a FARM document (which gets divided into FARM passages depending on document length and max_seq_len)
-        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage
-
     def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
               use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
               max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
@@ -140,6 +104,49 @@ class FARMReader:
         self.inferencer.model = trainer.train()
         self.save(save_dir)
 
+    def __init__(
+            self,
+            model_name_or_path,
+            context_window_size=30,
+            batch_size=50,
+            use_gpu=True,
+            no_ans_boost=None,
+            n_candidates_per_paragraph=1):
+        """
+        :param model_name_or_path: directory of a saved model or the name of a public model:
+                                   - 'bert-base-cased'
+                                   - 'deepset/bert-base-cased-squad2'
+                                   - 'deepset/bert-base-cased-squad2'
+                                   - 'distilbert-base-uncased-distilled-squad'
+                                   ....
+                                   See https://huggingface.co/models for full list of available models.
+        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
+        :param batch_size: Number of samples the model receives in one batch for inference.
+                           Memory consumption is much lower in inference mode. Recommendation: increase the batch size so that only a single batch is needed.
+        :param use_gpu: Whether to use GPU (if available)
+        :param no_ans_boost: How much the no_answer logit is boosted/increased.
+                             Possible values: None (default) = disable returning "no answer" predictions
+                                              Negative = lower chance of "no answer" being predicted
+                                              Positive = increase chance of "no answer"
+        :param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
+                                           Note: - This is not the number of "final answers" you will receive
+                                                   (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
+                                                 - FARM includes no_answer in the sorted list of predictions
+
+
+        """
+
+        if no_ans_boost is None:
+            no_ans_boost = 0
+            self.return_no_answers = False
+        else:
+            self.return_no_answers = True
+        self.n_candidates_per_paragraph = n_candidates_per_paragraph
+        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
+        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
+        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
+        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1  # including possible no_answer
+
     def save(self, directory):
         logger.info(f"Saving reader model to {directory}")
         self.inferencer.model.save(directory)
@@ -159,7 +166,7 @@ class FARMReader:
                        'offset_answer_end': 154,
                        'probability': 0.9787139466668613,
                        'score': None,
-                       'document_id': None
+                       'document_id': '1337'
                        },
                       ...
                    ]
@@ -199,12 +206,11 @@ class FARMReader:
         no_ans_gaps = []
         best_score_answer = 0
         for pred in predictions:
-            positive_found = False
+            answers_per_paragraph = []
+            no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"])
             for a in pred["predictions"][0]["answers"]:
                 # skip "no answers" here
-                # For now we only take one prediction from each passage
-                # TODO use more predictions per passage when setting n_candidates_per_passage + make FARM predictions more varied
-                if(not positive_found and a["answer"]):
+                if a["answer"]:
                     cur = {"answer": a["answer"],
                            "score": a["score"],
                            "probability": float(expit(np.asarray([a["score"]]) / 8)), #just a pseudo prob for now
@@ -212,35 +218,17 @@ class FARMReader:
                            "offset_start": a["offset_answer_start"] - a["offset_context_start"],
                            "offset_end": a["offset_answer_end"] - a["offset_context_start"],
                            "document_id": a["document_id"]}
-                    answers.append(cur)
-                    no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"])
+                    answers_per_paragraph.append(cur)
+
                     if a["score"] > best_score_answer:
                         best_score_answer = a["score"]
-                    positive_found = True
+            answers += answers_per_paragraph[:self.n_candidates_per_paragraph]
 
-        # "no answer" scores and positive answer scores are difficult to compare, because
-        # + a positive answer score is related to one specific document
-        # - a "no answer" score is related to all input documents
-        # Thus we compute the "no answer" score relative to the best possible answer and adjust it by
-        # the most significant difference between scores.
-        # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
-        # No_ans_gap coming from FARM means how much no_ans_boost should change to switch predictions
-        no_ans_gaps = np.array(no_ans_gaps)
-        max_no_ans_gap = np.max(no_ans_gaps)
-        if(np.sum(no_ans_gaps < 0) == len(no_ans_gaps)):  # all passages "no answer" as top score
-            no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative, so it increases best pos score
-        else:  # case: at least one passage predicts an answer (positive no_ans_gap)
-            no_ans_score = best_score_answer - max_no_ans_gap
-
-        cur = {"answer": None,
-               "score": no_ans_score,
-               "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
-               "context": None,
-               "offset_start": 0,
-               "offset_end": 0,
-               "document_id": None}
-        answers.append(cur)
+        # Calculate the score for predicting "no answer", relative to our best positive answer score
+        no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps, best_score_answer)
+        if self.return_no_answers:
+            answers.append(no_ans_prediction)
 
         # sort answers by their `probability` and select top-k
         answers = sorted(
@@ -251,4 +239,29 @@
                   "no_ans_gap": max_no_ans_gap,
                   "answers": answers}
 
-        return result
\ No newline at end of file
+        return result
+
+    @staticmethod
+    def _calc_no_answer(no_ans_gaps, best_score_answer):
+        # "no answer" scores and positive answer scores are difficult to compare, because
+        # + a positive answer score is related to one specific document
+        # - a "no answer" score is related to all input documents
+        # Thus we compute the "no answer" score relative to the best possible answer and adjust it by
+        # the most significant difference between scores.
+        # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
+        # No_ans_gap coming from FARM means how much no_ans_boost should change to switch predictions
+        no_ans_gaps = np.array(no_ans_gaps)
+        max_no_ans_gap = np.max(no_ans_gaps)
+        if (np.sum(no_ans_gaps < 0) == len(no_ans_gaps)):  # all passages "no answer" as top score
+            no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative, so it increases best pos score
+        else:  # case: at least one passage predicts an answer (positive no_ans_gap)
+            no_ans_score = best_score_answer - max_no_ans_gap
+
+        no_ans_prediction = {"answer": None,
+                             "score": no_ans_score,
+                             "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
+                             "context": None,
+                             "offset_start": 0,
+                             "offset_end": 0,
+                             "document_id": None}
+        return no_ans_prediction, max_no_ans_gap
\ No newline at end of file
diff --git a/tutorials/Tutorial1_Basic_QA_Pipeline.py b/tutorials/Tutorial1_Basic_QA_Pipeline.py
index 6efdfc900..b8f1eaa5e 100755
--- a/tutorials/Tutorial1_Basic_QA_Pipeline.py
+++ b/tutorials/Tutorial1_Basic_QA_Pipeline.py
@@ -35,10 +35,10 @@ retriever = TfidfRetriever(document_store=document_store)
 
 # A reader scans the text chunks in detail and extracts the k best answers
 # Reader use more powerful but slower deep learning models
-# You can select a local model or any of the QA models published on huggingface's model hub (https://huggingface.co/models)
+# You can select a local model or any of the QA models published on huggingface's model hub (https://huggingface.co/models)
 # here: a medium sized BERT QA model trained via FARM on Squad 2.0
 # You can adjust the model to return "no answer possible" with the no_ans_boost. Higher values mean the model prefers "no answer possible"
-reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False, no_ans_boost=0)
+reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False)
 
 # OR: use alternatively a reader from huggingface's transformers package (https://github.com/huggingface/transformers)
 # reader = TransformersReader(model="distilbert-base-uncased-distilled-squad", tokenizer="distilbert-base-uncased", use_gpu=-1)
@@ -52,8 +52,8 @@ finder = Finder(reader, retriever)
 prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
 
 # to test impossible questions we need a large QA model, e.g. deepset/bert-large-uncased-whole-word-masking-squad2
-#prediction = finder.get_answers(question="Who is the first daughter of Arya Stark?", top_k_retriever=10, top_k_reader=5)
-
+# and we need to enable returning "no answer possible" by setting no_ans_boost=X in FARMReader
+# prediction = finder.get_answers(question="Who is the first daughter of Arya Stark?", top_k_retriever=10, top_k_reader=5)
 
 #prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
 #prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
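After this patch, "no answer" handling is opt-in. A usage sketch against the new __init__ (model name as in the tutorial; the comments restate the flag logic added above, and the import path assumes the module layout shown in the diff headers):

from haystack.reader.farm import FARMReader

# Default: no_ans_boost=None disables "no answer" predictions entirely.
reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False)

# Any number enables them: 0 keeps the model's raw preference, negative values
# make "no answer" less likely, positive values make it more likely.
reader_no_ans = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2",
                           use_gpu=False,
                           no_ans_boost=0,
                           n_candidates_per_paragraph=2)

# reader.return_no_answers is False, reader_no_ans.return_no_answers is True;
# only the latter may include {"answer": None, "context": None, ...} in the
# results of predict().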
From d15448f60a8f11a25a6b841b991c04c20f0c6ad7 Mon Sep 17 00:00:00 2001
From: timoeller
Date: Mon, 24 Feb 2020 16:23:13 +0100
Subject: [PATCH 4/5] Redo changing of init

---
 haystack/reader/farm.py | 86 ++++++++++++++++++++---------------
 1 file changed, 43 insertions(+), 43 deletions(-)

diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index 2877c5a28..55cbdbf2a 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -23,6 +23,49 @@ class FARMReader:
     - fine-tune the model on QA data via train()
     """
 
+    def __init__(
+            self,
+            model_name_or_path,
+            context_window_size=30,
+            batch_size=50,
+            use_gpu=True,
+            no_ans_boost=None,
+            n_candidates_per_paragraph=1):
+        """
+        :param model_name_or_path: directory of a saved model or the name of a public model:
+                                   - 'bert-base-cased'
+                                   - 'deepset/bert-base-cased-squad2'
+                                   - 'deepset/bert-base-cased-squad2'
+                                   - 'distilbert-base-uncased-distilled-squad'
+                                   ....
+                                   See https://huggingface.co/models for full list of available models.
+        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
+        :param batch_size: Number of samples the model receives in one batch for inference.
+                           Memory consumption is much lower in inference mode. Recommendation: increase the batch size so that only a single batch is needed.
+        :param use_gpu: Whether to use GPU (if available)
+        :param no_ans_boost: How much the no_answer logit is boosted/increased.
+                             Possible values: None (default) = disable returning "no answer" predictions
+                                              Negative = lower chance of "no answer" being predicted
+                                              Positive = increase chance of "no answer"
+        :param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
+                                           Note: - This is not the number of "final answers" you will receive
+                                                   (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
+                                                 - FARM includes no_answer in the sorted list of predictions
+
+
+        """
+
+        if no_ans_boost is None:
+            no_ans_boost = 0
+            self.return_no_answers = False
+        else:
+            self.return_no_answers = True
+        self.n_candidates_per_paragraph = n_candidates_per_paragraph
+        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
+        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
+        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
+        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1  # including possible no_answer
+
     def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
               use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
               max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
@@ -104,49 +147,6 @@ class FARMReader:
         self.inferencer.model = trainer.train()
         self.save(save_dir)
 
-    def __init__(
-            self,
-            model_name_or_path,
-            context_window_size=30,
-            batch_size=50,
-            use_gpu=True,
-            no_ans_boost=None,
-            n_candidates_per_paragraph=1):
-        """
-        :param model_name_or_path: directory of a saved model or the name of a public model:
-                                   - 'bert-base-cased'
-                                   - 'deepset/bert-base-cased-squad2'
-                                   - 'deepset/bert-base-cased-squad2'
-                                   - 'distilbert-base-uncased-distilled-squad'
-                                   ....
-                                   See https://huggingface.co/models for full list of available models.
-        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
-        :param batch_size: Number of samples the model receives in one batch for inference.
-                           Memory consumption is much lower in inference mode. Recommendation: increase the batch size so that only a single batch is needed.
-        :param use_gpu: Whether to use GPU (if available)
-        :param no_ans_boost: How much the no_answer logit is boosted/increased.
-                             Possible values: None (default) = disable returning "no answer" predictions
-                                              Negative = lower chance of "no answer" being predicted
-                                              Positive = increase chance of "no answer"
-        :param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
-                                           Note: - This is not the number of "final answers" you will receive
-                                                   (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
-                                                 - FARM includes no_answer in the sorted list of predictions
-
-
-        """
-
-        if no_ans_boost is None:
-            no_ans_boost = 0
-            self.return_no_answers = False
-        else:
-            self.return_no_answers = True
-        self.n_candidates_per_paragraph = n_candidates_per_paragraph
-        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
-        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
-        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
-        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1  # including possible no_answer
-
     def save(self, directory):
         logger.info(f"Saving reader model to {directory}")
         self.inferencer.model.save(directory)
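The `n_best = n_candidates_per_paragraph + 1` line above reserves one slot for FARM's no_answer candidate, which FARM keeps inside its sorted n-best list. A toy sketch of why the extra slot matters (the candidate dicts are made up; an empty answer marks no_answer, matching the `if a["answer"]` check in predict()):

n_candidates_per_paragraph = 1
farm_n_best = n_candidates_per_paragraph + 1   # one extra slot for a possible no_answer

# FARM returns its n-best candidates sorted by score. With only n_best=1,
# a paragraph where no_answer ranks first would yield no positive candidate.
candidates = [{"answer": "", "score": 11.2},              # no_answer ranked first
              {"answer": "Eddard Stark", "score": 10.4},
              {"answer": "Ned", "score": 7.9}]

top_n = candidates[:farm_n_best]
positive = [c for c in top_n if c["answer"]][:n_candidates_per_paragraph]
print(positive)   # [{'answer': 'Eddard Stark', 'score': 10.4}]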
From 96a0847b32f6284c23b6254b04d1b4e497526355 Mon Sep 17 00:00:00 2001
From: timoeller
Date: Mon, 24 Feb 2020 16:26:59 +0100
Subject: [PATCH 5/5] Add comment

---
 haystack/reader/farm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index 55cbdbf2a..da430a537 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -222,7 +222,7 @@ class FARMReader:
 
                     if a["score"] > best_score_answer:
                         best_score_answer = a["score"]
-
+            # only take n best candidates. Answers coming back from FARM are sorted with decreasing relevance.
            answers += answers_per_paragraph[:self.n_candidates_per_paragraph]
 
         # Calculate the score for predicting "no answer", relative to our best positive answer score
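Taken together, the series leaves predict() merging the per-paragraph candidates, optionally appending the "no answer" prediction, and ranking everything before the top_k cut. A toy sketch of that final ranking step (scores invented; the sort key follows the "sort answers by their `probability`" comment in the code):

import numpy as np
from scipy.special import expit

answers = [{"answer": "Eddard Stark", "score": 10.4},
           {"answer": "Ned", "score": 7.9},
           {"answer": None, "score": 8.0}]   # as returned by _calc_no_answer

# the same pseudo probability as computed in predict()
for a in answers:
    a["probability"] = float(expit(np.asarray(a["score"]) / 8))

top_k = 2
answers = sorted(answers, key=lambda x: x["probability"], reverse=True)[:top_k]
print([a["answer"] for a in answers])   # ['Eddard Stark', None]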