Merge pull request #47 from deepset-ai/train_params

FARMReader.train() now takes its defaults for use_gpu and max_seq_len from the FARMReader instance, so values set once at initialization no longer need to be repeated at training time.
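
As a quick sketch of the new behaviour (the data paths and the exact set of constructor arguments shown here are illustrative, not taken from this diff; the import path assumes the tutorial layout of this era):

    from haystack.reader.farm import FARMReader

    # Settings chosen once at construction time...
    reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad",
                        use_gpu=False, max_seq_len=384)

    # ...are now the defaults for training: this call falls back to
    # use_gpu=False and max_seq_len=384 from the reader instance.
    reader.train(data_dir="PATH/TO_YOUR/TRAIN_DATA", train_filename="train.json", n_epochs=1)

    # An explicit argument still wins over the init-time value.
    reader.train(data_dir="PATH/TO_YOUR/TRAIN_DATA", train_filename="train.json",
                 max_seq_len=256, n_epochs=1)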
Branden Chan 2020-03-24 18:29:26 +01:00 committed by GitHub
commit 5932aa01c3
2 changed files with 13 additions and 3 deletions


@@ -87,10 +87,12 @@ class FARMReader:
         except:
             logger.warning("Could not set `top_k_per_sample` in FARM. Please update FARM version.")
         self.max_processes = max_processes
+        self.max_seq_len = max_seq_len
+        self.use_gpu = use_gpu
 
     def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
-              use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
-              max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
+              use_gpu=None, batch_size=10, n_epochs=2, learning_rate=1e-5,
+              max_seq_len=None, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
         """
         Fine-tune a model on a QA dataset. Options:
         - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
@@ -120,6 +122,14 @@ class FARMReader:
             dev_split = None
         set_all_seeds(seed=42)
+
+        # For these variables, by default, we use the value set when initializing the FARMReader.
+        # These can also be manually set when train() is called if you want a different value at train vs inference
+        if use_gpu is None:
+            use_gpu = self.use_gpu
+        if max_seq_len is None:
+            max_seq_len = self.max_seq_len
+
         device, n_gpu = initialize_device_settings(use_cuda=use_gpu)
         if not save_dir:
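
Note the design choice visible in this hunk: the signature defaults moved from concrete values (use_gpu=True, max_seq_len=256) to None, which acts as a sentinel meaning "not set by the caller". A concrete default could not distinguish an explicit reader.train(..., use_gpu=True) from "use whatever the instance was configured with". A minimal standalone sketch of the pattern (class and attribute names are illustrative, not Haystack API):

    class Trainer:
        """Illustrative only: init-time values become train-time defaults."""

        def __init__(self, use_gpu=True, max_seq_len=256):
            self.use_gpu = use_gpu
            self.max_seq_len = max_seq_len

        def train(self, use_gpu=None, max_seq_len=None):
            # None signals "the caller did not pass a value": fall back to
            # the instance attribute set in __init__.
            use_gpu = self.use_gpu if use_gpu is None else use_gpu
            max_seq_len = self.max_seq_len if max_seq_len is None else max_seq_len
            return use_gpu, max_seq_len

    assert Trainer(use_gpu=False).train() == (False, 256)             # inherits init value
    assert Trainer(use_gpu=False).train(use_gpu=True) == (True, 256)  # explicit override wins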


@@ -13,7 +13,7 @@ reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad"
 # and fine-tune it on your own custom dataset (should be in SQuAD like format)
 train_data = "PATH/TO_YOUR/TRAIN_DATA"
-reader.train(data_dir=train_data, train_filename="train.json", use_gpu=False, n_epochs=1)
+reader.train(data_dir=train_data, train_filename="train.json", n_epochs=1)
 
 #### Use it (same as in Tutorial 1) #############
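
The tutorial change mirrors the API change: use_gpu=False is dropped from the train() call because the reader's own setting now applies. If CPU-only training is still wanted, the flag would be set once at construction instead (a sketch; the reader's actual constructor arguments are truncated in the hunk header above):

    reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad",
                        use_gpu=False)  # set once here; train() now inherits it
    reader.train(data_dir=train_data, train_filename="train.json", n_epochs=1)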