FARMReader.train() uses values from class

Branden Chan 2020-03-24 17:42:50 +01:00
parent 9897c40c41
commit d41dcca813


@@ -87,10 +87,12 @@ class FARMReader:
         except:
             logger.warning("Could not set `top_k_per_sample` in FARM. Please update FARM version.")
         self.max_processes = max_processes
+        self.max_seq_len = max_seq_len
+        self.use_gpu = use_gpu

     def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
-              use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
-              max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
+              use_gpu=None, batch_size=10, n_epochs=2, learning_rate=1e-5,
+              max_seq_len=None, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
         """
         Fine-tune a model on a QA dataset. Options:
         - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
@@ -120,6 +122,14 @@ class FARMReader:
             dev_split = None

         set_all_seeds(seed=42)
+
+        # For these variables, by default, we use the value set when initializing the FARMReader.
+        # This can also be manually set when train() is called if you want a different value at train vs inference
+        if use_gpu is None:
+            use_gpu = self.use_gpu
+        if max_seq_len is None:
+            max_seq_len = self.max_seq_len
+
         device, n_gpu = initialize_device_settings(use_cuda=use_gpu)
         if not save_dir:
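
A minimal usage sketch of the behaviour introduced by this commit: train() now falls back to the use_gpu and max_seq_len values stored on the reader at init time unless they are passed explicitly. The import path, constructor arguments, and file names below are illustrative assumptions, not taken from this diff.

    from haystack.reader.farm import FARMReader  # import path assumed for this Haystack version

    # use_gpu and max_seq_len are stored on the instance in __init__ (other args are illustrative)
    reader = FARMReader(model_name_or_path="bert-base-cased", use_gpu=True, max_seq_len=384)

    # No use_gpu / max_seq_len given: train() picks up self.use_gpu (True) and self.max_seq_len (384)
    reader.train(data_dir="data/squad20", train_filename="train-v2.0.json")

    # Explicit arguments still override the values set at init time
    reader.train(data_dir="data/squad20", train_filename="train-v2.0.json",
                 use_gpu=False, max_seq_len=256)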