mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-18 12:43:28 +00:00
Add checkpointing for reader.train() to allow stopping + resuming training (#1554)
* add checkpoint creation to the train function in FARMReader * add arguments for the create_or_load_checkpoint function * access the class method inside the Trainer class * add a default value for checkpoint_root_dir, and add checkpoint_every and checkpoints_to_keep as arguments for reader.train() * change the default values for checkpoint_root_dir and checkpoint_every * update the docstring and add Path conversion Co-authored-by: girish.koushik <girish.koushik@diatoz.com> Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
parent
575e64333c
commit
5a6285f23f
@ -163,13 +163,19 @@ class FARMReader(BaseReader):
|
|||||||
save_dir: Optional[str] = None,
|
save_dir: Optional[str] = None,
|
||||||
num_processes: Optional[int] = None,
|
num_processes: Optional[int] = None,
|
||||||
use_amp: str = None,
|
use_amp: str = None,
|
||||||
|
checkpoint_root_dir: Path = Path("model_checkpoints"),
|
||||||
|
checkpoint_every: Optional[int] = None,
|
||||||
|
checkpoints_to_keep: int = 3,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Fine-tune a model on a QA dataset. Options:
|
Fine-tune a model on a QA dataset. Options:
|
||||||
|
|
||||||
- Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
|
- Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
|
||||||
- Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
|
- Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
|
||||||
|
|
||||||
|
Checkpoints can be stored via setting `checkpoint_every` to a custom number of steps.
|
||||||
|
If any checkpoints are stored, a subsequent run of train() will resume training from the latest available checkpoint.
|
||||||
|
|
||||||
:param data_dir: Path to directory containing your training data in SQuAD style
|
:param data_dir: Path to directory containing your training data in SQuAD style
|
||||||
:param train_filename: Filename of training data
|
:param train_filename: Filename of training data
|
||||||
:param dev_filename: Filename of dev / eval data
|
:param dev_filename: Filename of dev / eval data
|
||||||
@ -197,6 +203,10 @@ class FARMReader(BaseReader):
|
|||||||
"O2" (Almost FP16)
|
"O2" (Almost FP16)
|
||||||
"O3" (Pure FP16).
|
"O3" (Pure FP16).
|
||||||
See details on: https://nvidia.github.io/apex/amp.html
|
See details on: https://nvidia.github.io/apex/amp.html
|
||||||
|
:param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual
|
||||||
|
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
|
||||||
|
:param checkpoint_every: save a train checkpoint after this many steps of training.
|
||||||
|
:param checkpoints_to_keep: maximum number of train checkpoints to save.
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -251,7 +261,7 @@ class FARMReader(BaseReader):
|
|||||||
use_amp=use_amp,
|
use_amp=use_amp,
|
||||||
)
|
)
|
||||||
# 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
|
# 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
|
||||||
trainer = Trainer(
|
trainer = Trainer.create_or_load_checkpoint(
|
||||||
model=model,
|
model=model,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
data_silo=data_silo,
|
data_silo=data_silo,
|
||||||
@ -261,10 +271,12 @@ class FARMReader(BaseReader):
|
|||||||
evaluate_every=evaluate_every,
|
evaluate_every=evaluate_every,
|
||||||
device=device,
|
device=device,
|
||||||
use_amp=use_amp,
|
use_amp=use_amp,
|
||||||
disable_tqdm=not self.progress_bar
|
disable_tqdm=not self.progress_bar,
|
||||||
|
checkpoint_root_dir=Path(checkpoint_root_dir),
|
||||||
|
checkpoint_every=checkpoint_every,
|
||||||
|
checkpoints_to_keep=checkpoints_to_keep,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 5. Let it grow!
|
# 5. Let it grow!
|
||||||
self.inferencer.model = trainer.train()
|
self.inferencer.model = trainer.train()
|
||||||
self.save(Path(save_dir))
|
self.save(Path(save_dir))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user