diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index 2877c5a28..55cbdbf2a 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -23,6 +23,49 @@ class FARMReader:
     - fine-tune the model on QA data via train()
     """
 
+    def __init__(
+        self,
+        model_name_or_path,
+        context_window_size=30,
+        batch_size=50,
+        use_gpu=True,
+        no_ans_boost=None,
+        n_candidates_per_paragraph=1):
+        """
+        :param model_name_or_path: directory of a saved model or the name of a public model:
+                                   - 'bert-base-cased'
+                                   - 'deepset/bert-base-cased-squad2'
+                                   - 'deepset/bert-base-cased-squad2'
+                                   - 'distilbert-base-uncased-distilled-squad'
+                                   ....
+                                   See https://huggingface.co/models for full list of available models.
+        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
+        :param batch_size: Number of samples the model receives in one batch for inference
+                           Memory consumption is much lower in inference mode. Recommendation: increase the batch size to a value so only a single batch is used.
+        :param use_gpu: Whether to use GPU (if available)
+        :param no_ans_boost: How much the no_answer logit is boosted/increased.
+                             Possible values: None (default) = disable returning "no answer" predictions
+                                              Negative = lower chance of "no answer" being predicted
+                                              Positive = increase chance of "no answer"
+        :param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
+                                           Note: - This is not the number of "final answers" you will receive
+                                                   (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
+                                                 - FARM includes no_answer in the sorted list of predictions
+
+
+        """
+
+        if no_ans_boost is None:
+            no_ans_boost = 0
+            self.return_no_answers = False
+        else:
+            self.return_no_answers = True
+        self.n_candidates_per_paragraph = n_candidates_per_paragraph
+        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
+        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
+        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
+        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1  # including possible no_answer
+
     def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None, use_gpu=True,
               batch_size=10, n_epochs=2, learning_rate=1e-5, max_seq_len=256, warmup_proportion=0.2,
               dev_split=0.1, evaluate_every=300, save_dir=None):
@@ -104,49 +147,6 @@ class FARMReader:
         self.inferencer.model = trainer.train()
         self.save(save_dir)
 
-    def __init__(
-        self,
-        model_name_or_path,
-        context_window_size=30,
-        batch_size=50,
-        use_gpu=True,
-        no_ans_boost=None,
-        n_candidates_per_paragraph=1):
-        """
-        :param model_name_or_path: directory of a saved model or the name of a public model:
-                                   - 'bert-base-cased'
-                                   - 'deepset/bert-base-cased-squad2'
-                                   - 'deepset/bert-base-cased-squad2'
-                                   - 'distilbert-base-uncased-distilled-squad'
-                                   ....
-                                   See https://huggingface.co/models for full list of available models.
-        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
-        :param batch_size: Number of samples the model receives in one batch for inference
-                           Memory consumption is much lower in inference mode. Recommendation: increase the batch size to a value so only a single batch is used.
-        :param use_gpu: Whether to use GPU (if available)
-        :param no_ans_boost: How much the no_answer logit is boosted/increased.
-                             Possible values: None (default) = disable returning "no answer" predictions
-                                              Negative = lower chance of "no answer" being predicted
-                                              Positive = increase chance of "no answer"
-        :param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
-                                           Note: - This is not the number of "final answers" you will receive
-                                                   (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
-                                                 - FARM includes no_answer in the sorted list of predictions
-
-
-        """
-
-        if no_ans_boost is None:
-            no_ans_boost = 0
-            self.return_no_answers = False
-        else:
-            self.return_no_answers = True
-        self.n_candidates_per_paragraph = n_candidates_per_paragraph
-        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
-        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
-        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
-        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1  # including possible no_answer
-
     def save(self, directory):
         logger.info(f"Saving reader model to {directory}")
         self.inferencer.model.save(directory)
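
For reference, a minimal usage sketch of the constructor moved above (not part of this diff; the import path follows the file modified here, and the model name is one of the docstring's examples):

    # Usage sketch only; assumes haystack is installed and importable as shown in the diff.
    from haystack.reader.farm import FARMReader

    # Default no_ans_boost=None keeps return_no_answers=False, so "no answer" is never returned.
    reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2")

    # Passing any numeric no_ans_boost sets return_no_answers=True; negative values make
    # "no answer" less likely, positive values make it more likely.
    cautious_reader = FARMReader(
        model_name_or_path="deepset/bert-base-cased-squad2",
        context_window_size=100,       # characters of surrounding context shown with each answer
        batch_size=16,
        use_gpu=False,
        no_ans_boost=-10,
        n_candidates_per_paragraph=3,  # candidates per passage; final answer count is set via top_k at predict time
    )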