mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 09:49:48 +00:00 
			
		
		
		
	Add automatic mixed precision (AMP) support for reader training (#463)
* Added automatic mixed precision (AMP) support for reader training
* Added clearer comments in the docstring
This commit is contained in:
		
							parent
							
								
									955e6f7b3a
								
							
						
					
					
						commit
						3caaf99dcb
					
				| @ -119,7 +119,8 @@ class FARMReader(BaseReader): | ||||
|         dev_split: float = 0, | ||||
|         evaluate_every: int = 300, | ||||
|         save_dir: Optional[str] = None, | ||||
|         num_processes: Optional[int] = None | ||||
|         num_processes: Optional[int] = None, | ||||
|         use_amp: str = None, | ||||
|     ): | ||||
|         """ | ||||
|         Fine-tune a model on a QA dataset. Options: | ||||
| @ -146,6 +147,14 @@ class FARMReader(BaseReader): | ||||
|         :param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing. | ||||
|                               Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set. | ||||
|                               Set to None to use all CPU cores minus one. | ||||
|         :param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model. | ||||
|                         Available options: | ||||
|                         None (Don't use AMP) | ||||
|                         "O0" (Normal FP32 training) | ||||
|                         "O1" (Mixed Precision => Recommended) | ||||
|                         "O2" (Almost FP16) | ||||
|                         "O3" (Pure FP16). | ||||
|                         See details on: https://nvidia.github.io/apex/amp.html | ||||
|         :return: None | ||||
|         """ | ||||
| 
 | ||||
| @ -164,7 +173,7 @@ class FARMReader(BaseReader): | ||||
|         if max_seq_len is None: | ||||
|             max_seq_len = self.max_seq_len | ||||
| 
 | ||||
|         device, n_gpu = initialize_device_settings(use_cuda=use_gpu) | ||||
|         device, n_gpu = initialize_device_settings(use_cuda=use_gpu,use_amp=use_amp) | ||||
| 
 | ||||
|         if not save_dir: | ||||
|             save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}" | ||||
| @ -203,7 +212,8 @@ class FARMReader(BaseReader): | ||||
|             schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion}, | ||||
|             n_batches=len(data_silo.loaders["train"]), | ||||
|             n_epochs=n_epochs, | ||||
|             device=device | ||||
|             device=device, | ||||
|             use_amp=use_amp, | ||||
|         ) | ||||
|         # 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time | ||||
|         trainer = Trainer( | ||||
| @ -215,6 +225,7 @@ class FARMReader(BaseReader): | ||||
|             lr_schedule=lr_schedule, | ||||
|             evaluate_every=evaluate_every, | ||||
|             device=device, | ||||
|             use_amp=use_amp, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Antonio Lanza
						Antonio Lanza