	Add checkpointing for reader.train() to allow stopping + resuming training (#1554)
* adding create checkpoint feature for train function in farm reader
* added arguments for create_or_load_checkpoint function
* accessing class method inside Trainer class
* added default value for checkpoint_root_dir and checkpoint_every, checkpoints_to_keep as arguments for reader.train()
* change in default value for checkpoint_root_dir and checkpoint_every
* update docstring and add Path conversion

Co-authored-by: girish.koushik <girish.koushik@diatoz.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:

parent 575e64333c
commit 5a6285f23f
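The new arguments can be exercised directly from the reader API. Below is a minimal usage sketch; the model name, data paths, and step counts are illustrative (not from this commit), and the import path assumes the Haystack 0.x module layout:

    from pathlib import Path

    from haystack.reader.farm import FARMReader  # Haystack 0.x import path (assumption)

    reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2")

    reader.train(
        data_dir="data/squad",                          # SQuAD-style training data (illustrative path)
        train_filename="train.json",
        checkpoint_root_dir=Path("model_checkpoints"),  # default from the new signature
        checkpoint_every=100,                           # save a train checkpoint every 100 steps
        checkpoints_to_keep=3,                          # keep at most the 3 newest checkpoints
    )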
					
@@ -163,13 +163,19 @@ class FARMReader(BaseReader):
         save_dir: Optional[str] = None,
         num_processes: Optional[int] = None,
         use_amp: str = None,
+        checkpoint_root_dir: Path = Path("model_checkpoints"),
+        checkpoint_every: Optional[int] = None,
+        checkpoints_to_keep: int = 3,
     ):
         """
         Fine-tune a model on a QA dataset. Options:
 
         - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
         - Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
 
+        Checkpoints can be stored via setting `checkpoint_every` to a custom number of steps.
+        If any checkpoints are stored, a subsequent run of train() will resume training from the latest available checkpoint.
+
         :param data_dir: Path to directory containing your training data in SQuAD style
         :param train_filename: Filename of training data
         :param dev_filename: Filename of dev / eval data
@@ -197,6 +203,10 @@ class FARMReader(BaseReader):
                         "O2" (Almost FP16)
                         "O3" (Pure FP16).
                         See details on: https://nvidia.github.io/apex/amp.html
+        :param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual
+               checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
+        :param checkpoint_every: save a train checkpoint after this many steps of training.
+        :param checkpoints_to_keep: maximum number of train checkpoints to save.
         :return: None
         """
 
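Given the epoch_{epoch_num}_step_{step_num} naming scheme documented above, a hypothetical helper (not part of this commit) could locate the newest checkpoint on disk like so:

    from pathlib import Path
    from typing import Optional

    def latest_checkpoint(checkpoint_root_dir: Path = Path("model_checkpoints")) -> Optional[Path]:
        """Hypothetical helper: return the newest checkpoint subdirectory, or None."""
        def epoch_and_step(p: Path):
            # "epoch_1_step_500" -> (1, 500)
            _, epoch, _, step = p.name.split("_")
            return int(epoch), int(step)
        candidates = [p for p in checkpoint_root_dir.glob("epoch_*_step_*") if p.is_dir()]
        return max(candidates, key=epoch_and_step, default=None)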
@@ -251,7 +261,7 @@ class FARMReader(BaseReader):
             use_amp=use_amp,
         )
         # 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
-        trainer = Trainer(
+        trainer = Trainer.create_or_load_checkpoint(
             model=model,
             optimizer=optimizer,
             data_silo=data_silo,
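The switch from Trainer(...) to Trainer.create_or_load_checkpoint(...) is what makes resuming transparent to the caller. A sketch of the expected dispatch follows; this is not FARM's actual implementation, and the internal helper names are hypothetical:

    # Sketch only: the real classmethod accepts the same kwargs as Trainer plus
    # the checkpoint settings; the body here is illustrative.
    @classmethod
    def create_or_load_checkpoint(cls, checkpoint_root_dir, checkpoint_every=None,
                                  checkpoints_to_keep=3, **trainer_kwargs):
        resume_from = latest_checkpoint(checkpoint_root_dir)  # hypothetical lookup, see above
        if resume_from is not None:
            # Restore model/optimizer state and step counters, then continue training.
            return cls._load(resume_from, **trainer_kwargs)   # hypothetical restore helper
        # No checkpoint on disk yet: behave exactly like Trainer(**trainer_kwargs).
        return cls(**trainer_kwargs)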
@@ -261,10 +271,12 @@ class FARMReader(BaseReader):
             evaluate_every=evaluate_every,
             device=device,
             use_amp=use_amp,
-            disable_tqdm=not self.progress_bar
+            disable_tqdm=not self.progress_bar,
+            checkpoint_root_dir=Path(checkpoint_root_dir),
+            checkpoint_every=checkpoint_every,
+            checkpoints_to_keep=checkpoints_to_keep,
         )
-
 
         # 5. Let it grow!
         self.inferencer.model = trainer.train()
         self.save(Path(save_dir))
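Per the docstring above, resuming requires no extra flags: rerunning the same train() call with an unchanged checkpoint_root_dir picks up from the latest available checkpoint. A sketch, reusing the illustrative paths from the first example:

    # Interrupted run? Call train() again with the same checkpoint_root_dir and it
    # resumes from the newest epoch_{epoch_num}_step_{step_num} checkpoint.
    reader.train(
        data_dir="data/squad",
        train_filename="train.json",
        checkpoint_every=100,
        checkpoints_to_keep=3,
    )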