Add automatic mixed precision (AMP) support for reader training (#463)

* Added automatic mixed precision (AMP) support for reader training

* Added clearer comments to the docstring
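
A minimal usage sketch of the new parameter (the model name, data directory, and filename below are placeholders, not part of this commit; any AMP level other than None requires NVIDIA Apex to be installed):

```python
from haystack.reader.farm import FARMReader

# Placeholder model and data paths -- substitute your own.
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)

reader.train(
    data_dir="data/squad20",            # directory holding the SQuAD-format training file
    train_filename="train-v2.0.json",   # placeholder filename
    n_epochs=1,
    use_gpu=True,
    use_amp="O1",                       # None disables AMP; "O0"-"O3" pick an Apex optimization level
)
```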
Antonio Lanza 2020-10-12 21:53:05 +02:00 committed by GitHub
parent 955e6f7b3a
commit 3caaf99dcb


@@ -119,7 +119,8 @@ class FARMReader(BaseReader):
dev_split: float = 0,
evaluate_every: int = 300,
save_dir: Optional[str] = None,
-num_processes: Optional[int] = None
+num_processes: Optional[int] = None,
+use_amp: str = None,
):
"""
Fine-tune a model on a QA dataset. Options:
@@ -146,6 +147,14 @@ class FARMReader(BaseReader):
:param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing.
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
Set to None to use all CPU cores minus one.
+:param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
+Available options:
+None (Don't use AMP)
+"O0" (Normal FP32 training)
+"O1" (Mixed Precision => Recommended)
+"O2" (Almost FP16)
+"O3" (Pure FP16).
+See details on: https://nvidia.github.io/apex/amp.html
:return: None
"""
@@ -164,7 +173,7 @@ class FARMReader(BaseReader):
if max_seq_len is None:
max_seq_len = self.max_seq_len
-device, n_gpu = initialize_device_settings(use_cuda=use_gpu)
+device, n_gpu = initialize_device_settings(use_cuda=use_gpu, use_amp=use_amp)
if not save_dir:
save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}"
@@ -203,7 +212,8 @@ class FARMReader(BaseReader):
schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
n_batches=len(data_silo.loaders["train"]),
n_epochs=n_epochs,
-device=device
+device=device,
+use_amp=use_amp,
)
# 4. Feed everything to the Trainer, which takes care of growing our model and evaluates it from time to time
trainer = Trainer(
@@ -215,6 +225,7 @@ class FARMReader(BaseReader):
lr_schedule=lr_schedule,
evaluate_every=evaluate_every,
device=device,
+use_amp=use_amp,
)
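
For context on the `use_amp` levels documented in the new docstring: the string is ultimately an NVIDIA Apex `opt_level`. The sketch below shows the generic Apex pattern those levels refer to (see https://nvidia.github.io/apex/amp.html); it is illustrative only and not FARM's internal implementation:

```python
# Generic Apex AMP pattern behind the "O0"-"O3" opt_level strings; illustrative only.
import torch
from apex import amp

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# "O1" (the recommended level) casts whitelisted ops to FP16 while keeping FP32 master weights.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(4, 10).cuda()).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # loss scaling guards against FP16 gradient underflow
optimizer.step()
```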