mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-31 20:03:38 +00:00
Add automatic mixed precision (AMP) support for reader training (#463)
* Added automatic mixed precision (AMP) support for reader training * Added clearer comments on docstring
This commit is contained in:
parent
955e6f7b3a
commit
3caaf99dcb
@ -119,7 +119,8 @@ class FARMReader(BaseReader):
|
||||
dev_split: float = 0,
|
||||
evaluate_every: int = 300,
|
||||
save_dir: Optional[str] = None,
|
||||
num_processes: Optional[int] = None
|
||||
num_processes: Optional[int] = None,
|
||||
use_amp: str = None,
|
||||
):
|
||||
"""
|
||||
Fine-tune a model on a QA dataset. Options:
|
||||
@ -146,6 +147,14 @@ class FARMReader(BaseReader):
|
||||
:param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing.
|
||||
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
|
||||
Set to None to use all CPU cores minus one.
|
||||
:param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
|
||||
Available options:
|
||||
None (Don't use AMP)
|
||||
"O0" (Normal FP32 training)
|
||||
"O1" (Mixed Precision => Recommended)
|
||||
"O2" (Almost FP16)
|
||||
"O3" (Pure FP16).
|
||||
See details on: https://nvidia.github.io/apex/amp.html
|
||||
:return: None
|
||||
"""
|
||||
|
||||
@ -164,7 +173,7 @@ class FARMReader(BaseReader):
|
||||
if max_seq_len is None:
|
||||
max_seq_len = self.max_seq_len
|
||||
|
||||
device, n_gpu = initialize_device_settings(use_cuda=use_gpu)
|
||||
device, n_gpu = initialize_device_settings(use_cuda=use_gpu,use_amp=use_amp)
|
||||
|
||||
if not save_dir:
|
||||
save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}"
|
||||
@ -203,7 +212,8 @@ class FARMReader(BaseReader):
|
||||
schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
|
||||
n_batches=len(data_silo.loaders["train"]),
|
||||
n_epochs=n_epochs,
|
||||
device=device
|
||||
device=device,
|
||||
use_amp=use_amp,
|
||||
)
|
||||
# 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
|
||||
trainer = Trainer(
|
||||
@ -215,6 +225,7 @@ class FARMReader(BaseReader):
|
||||
lr_schedule=lr_schedule,
|
||||
evaluate_every=evaluate_every,
|
||||
device=device,
|
||||
use_amp=use_amp,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user