Add checkpointing for reader.train() to allow stopping + resuming training (#1554)

* add checkpoint creation for the train function in the FARM reader

* add arguments for the create_or_load_checkpoint function

* access the class method inside the Trainer class

* add default values for checkpoint_root_dir and checkpoint_every; add checkpoints_to_keep as an argument for reader.train()

* change the default values for checkpoint_root_dir and checkpoint_every

* update docstring and add Path conversion

Co-authored-by: girish.koushik <girish.koushik@diatoz.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
Girish A Koushik 2021-10-19 16:06:32 +05:30 committed by GitHub
parent 575e64333c
commit 5a6285f23f


@@ -163,6 +163,9 @@ class FARMReader(BaseReader):
         save_dir: Optional[str] = None,
         num_processes: Optional[int] = None,
         use_amp: str = None,
+        checkpoint_root_dir: Path = Path("model_checkpoints"),
+        checkpoint_every: Optional[int] = None,
+        checkpoints_to_keep: int = 3,
     ):
         """
         Fine-tune a model on a QA dataset. Options:
@@ -170,6 +173,9 @@ class FARMReader(BaseReader):
         - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
         - Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
+        Checkpoints can be stored via setting `checkpoint_every` to a custom number of steps.
+        If any checkpoints are stored, a subsequent run of train() will resume training from the latest available checkpoint.
         :param data_dir: Path to directory containing your training data in SQuAD style
         :param train_filename: Filename of training data
         :param dev_filename: Filename of dev / eval data
@@ -197,6 +203,10 @@ class FARMReader(BaseReader):
                        "O2" (Almost FP16)
                        "O3" (Pure FP16).
                        See details on: https://nvidia.github.io/apex/amp.html
+        :param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual
+               checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
+        :param checkpoint_every: save a train checkpoint after this many steps of training.
+        :param checkpoints_to_keep: maximum number of train checkpoints to save.
         :return: None
         """
@@ -251,7 +261,7 @@ class FARMReader(BaseReader):
             use_amp=use_amp,
         )
         # 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
-        trainer = Trainer(
+        trainer = Trainer.create_or_load_checkpoint(
             model=model,
             optimizer=optimizer,
             data_silo=data_silo,
@@ -261,10 +271,12 @@ class FARMReader(BaseReader):
             evaluate_every=evaluate_every,
             device=device,
             use_amp=use_amp,
-            disable_tqdm=not self.progress_bar
+            disable_tqdm=not self.progress_bar,
+            checkpoint_root_dir=Path(checkpoint_root_dir),
+            checkpoint_every=checkpoint_every,
+            checkpoints_to_keep=checkpoints_to_keep,
        )
         # 5. Let it grow!
         self.inferencer.model = trainer.train()
         self.save(Path(save_dir))
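
For reference, a minimal usage sketch of the new checkpointing options (not part of the commit above; the data directory, filenames, epoch count, and base model are placeholder choices, and the FARMReader import path may differ between Haystack versions):

# Sketch of fine-tuning with checkpointing enabled. Paths, filenames, n_epochs,
# and the base model below are placeholders, not values from this commit.
from pathlib import Path

from haystack.nodes import FARMReader  # import path may differ in older Haystack versions

reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)

reader.train(
    data_dir="data/squad",                          # placeholder dir with SQuAD-style data
    train_filename="train.json",                    # placeholder filename
    dev_filename="dev.json",                        # placeholder filename
    n_epochs=2,
    save_dir="my_model",
    checkpoint_root_dir=Path("model_checkpoints"),  # checkpoints go into epoch_{epoch_num}_step_{step_num}/ subdirs
    checkpoint_every=1000,                          # save a checkpoint every 1000 training steps
    checkpoints_to_keep=3,                          # keep at most 3 checkpoints on disk
)

# If training is interrupted, re-running the same train() call resumes from the latest
# checkpoint found under checkpoint_root_dir (via Trainer.create_or_load_checkpoint).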