From 5a6285f23fda7ead23e0993b4985db2bd2577fd4 Mon Sep 17 00:00:00 2001
From: Girish A Koushik <31188269+gak97@users.noreply.github.com>
Date: Tue, 19 Oct 2021 16:06:32 +0530
Subject: [PATCH] Add checkpointing for reader.train() to allow stopping + resuming training (#1554)

* adding create checkpoint feature for train function in farm reader

* added arguments for create_or_load_checkpoint function

* accessing class method inside Trainer class

* added default value for checkpoint_root_dir and checkpoint_every, checkpoints_to_keep as arguments for reader.train()

* change in default value for checkpoint_root_dir and checkpoint_every

* update docstring and add Path conversion

Co-authored-by: girish.koushik
Co-authored-by: Malte Pietsch
---
 haystack/reader/farm.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index fe6d8abfd..c621fa1f3 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -163,13 +163,19 @@ class FARMReader(BaseReader):
         save_dir: Optional[str] = None,
         num_processes: Optional[int] = None,
         use_amp: str = None,
+        checkpoint_root_dir: Path = Path("model_checkpoints"),
+        checkpoint_every: Optional[int] = None,
+        checkpoints_to_keep: int = 3,
     ):
         """
         Fine-tune a model on a QA dataset. Options:
 
         - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
         - Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
-
+
+        Checkpoints can be stored by setting `checkpoint_every` to a custom number of steps.
+        If any checkpoints are stored, a subsequent run of train() will resume training from the latest available checkpoint.
+
         :param data_dir: Path to directory containing your training data in SQuAD style
         :param train_filename: Filename of training data
         :param dev_filename: Filename of dev / eval data
@@ -197,6 +203,10 @@ class FARMReader(BaseReader):
                         "O2" (Almost FP16)
                         "O3" (Pure FP16).
                         See details on: https://nvidia.github.io/apex/amp.html
+        :param checkpoint_root_dir: The path of the directory where all train checkpoints are saved. For each individual
+               checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
+        :param checkpoint_every: Save a train checkpoint after this many steps of training.
+        :param checkpoints_to_keep: Maximum number of train checkpoints to save.
         :return: None
         """
 
@@ -251,7 +261,7 @@ class FARMReader(BaseReader):
             use_amp=use_amp,
         )
         # 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
-        trainer = Trainer(
+        trainer = Trainer.create_or_load_checkpoint(
             model=model,
             optimizer=optimizer,
             data_silo=data_silo,
@@ -261,10 +271,12 @@ class FARMReader(BaseReader):
             evaluate_every=evaluate_every,
             device=device,
             use_amp=use_amp,
-            disable_tqdm=not self.progress_bar
+            disable_tqdm=not self.progress_bar,
+            checkpoint_root_dir=Path(checkpoint_root_dir),
+            checkpoint_every=checkpoint_every,
+            checkpoints_to_keep=checkpoints_to_keep,
        )
 
-        # 5. Let it grow!
         self.inferencer.model = trainer.train()
         self.save(Path(save_dir))
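
Below is a minimal usage sketch of the new checkpointing arguments, not taken from the patch itself: the model name, data directory, file name, and the 500-step interval are placeholder values, while n_epochs and save_dir are pre-existing train() arguments.

    from pathlib import Path

    from haystack.reader.farm import FARMReader

    # Placeholder model; any QA model supported by FARMReader works the same way.
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")

    reader.train(
        data_dir="data/squad20",           # placeholder: directory with SQuAD-style data
        train_filename="train-v2.0.json",  # placeholder: training file name
        n_epochs=2,
        save_dir="my_model",
        # New in this patch: store a checkpoint every 500 steps (arbitrary choice)
        # and keep at most 3 of them under the default "model_checkpoints" directory.
        checkpoint_every=500,
        checkpoints_to_keep=3,
        checkpoint_root_dir=Path("model_checkpoints"),
    )

If training is interrupted, calling reader.train() again with the same checkpoint_root_dir resumes from the latest available checkpoint, as described in the docstring added above.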