diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py index 2a1731879..03be48b1a 100644 --- a/haystack/reader/farm.py +++ b/haystack/reader/farm.py @@ -87,10 +87,12 @@ class FARMReader: except: logger.warning("Could not set `top_k_per_sample` in FARM. Please update FARM version.") self.max_processes = max_processes + self.max_seq_len = max_seq_len + self.use_gpu = use_gpu def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None, - use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5, - max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None): + use_gpu=None, batch_size=10, n_epochs=2, learning_rate=1e-5, + max_seq_len=None, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None): """ Fine-tune a model on a QA dataset. Options: - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data) @@ -120,6 +122,14 @@ class FARMReader: dev_split = None set_all_seeds(seed=42) + + # For these variables, by default, we use the value set when initializing the FARMReader. + # This can also be manually set when train() is called if you want a different value at train vs inference + if use_gpu is None: + use_gpu = self.use_gpu + if max_seq_len is None: + max_seq_len = self.max_seq_len + device, n_gpu = initialize_device_settings(use_cuda=use_gpu) if not save_dir: