diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index 00ff0e284..36047460b 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -119,7 +119,8 @@ class FARMReader(BaseReader):
         dev_split: float = 0,
         evaluate_every: int = 300,
         save_dir: Optional[str] = None,
-        num_processes: Optional[int] = None
+        num_processes: Optional[int] = None,
+        use_amp: str = None,
     ):
         """
         Fine-tune a model on a QA dataset. Options:
@@ -146,6 +147,14 @@ class FARMReader(BaseReader):
         :param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing. Set to value of 1 to disable multiprocessing.
                               When set to 1, you cannot split away a dev set from train set. Set to None to use all CPU cores minus one.
+        :param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
+                        Available options:
+                        None (Don't use AMP)
+                        "O0" (Normal FP32 training)
+                        "O1" (Mixed Precision => Recommended)
+                        "O2" (Almost FP16)
+                        "O3" (Pure FP16).
+                        See details on: https://nvidia.github.io/apex/amp.html
         :return: None
         """


@@ -164,7 +173,7 @@ class FARMReader(BaseReader):
         if max_seq_len is None:
             max_seq_len = self.max_seq_len

-        device, n_gpu = initialize_device_settings(use_cuda=use_gpu)
+        device, n_gpu = initialize_device_settings(use_cuda=use_gpu,use_amp=use_amp)

         if not save_dir:
             save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}"
@@ -203,7 +212,8 @@ class FARMReader(BaseReader):
             schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
             n_batches=len(data_silo.loaders["train"]),
             n_epochs=n_epochs,
-            device=device
+            device=device,
+            use_amp=use_amp,
         )
         # 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
         trainer = Trainer(
@@ -215,6 +225,7 @@ class FARMReader(BaseReader):
             lr_schedule=lr_schedule,
             evaluate_every=evaluate_every,
             device=device,
+            use_amp=use_amp,
         )
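
For reference, a minimal usage sketch of the new parameter (not part of the patch). It assumes NVIDIA's apex package is installed and uses hypothetical model and data paths; "O1" is the mixed-precision level recommended in the docstring above.

    from haystack.reader.farm import FARMReader

    # Hypothetical model name and data paths, shown only for illustration.
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)
    reader.train(
        data_dir="data/squad",        # hypothetical directory with SQuAD-format data
        train_filename="train.json",  # hypothetical training file
        use_gpu=True,
        n_epochs=1,
        use_amp="O1",                 # mixed precision; requires apex (see docstring above)
    )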