Mirror of https://github.com/deepset-ai/haystack.git, synced 2026-01-08 13:06:29 +00:00
Migrating to use native Pytorch AMP (#2827)
* Started making changes to use native Pytorch AMP
* Updated compute_loss functions to use torch.cuda.amp.autocast
* Updating docstrings
* Add use_amp to trainer_checkpoint
* Removed mentions of apex and started to add the necessary warnings
* Removing unused instances of use_amp variable
* Added fast training test for FARMReader. Needed to add max_query_length as a parameter in FARMReader.__init__ and FARMReader.train
* Make max_query_length optional in FARMReader.train
* Update lg

Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Co-authored-by: agnieszka-m <amarzec13@gmail.com>
This commit is contained in:
parent
35e9ff26cc
commit
e84fae2894
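For orientation, this is the native torch.cuda.amp pattern that this commit moves the training code to. It is a minimal, self-contained sketch with a toy model and random data (it assumes a CUDA device is available); it is not Haystack's actual Trainer code:

```python
# Minimal sketch of native PyTorch AMP (torch.cuda.amp); illustrative only.
import torch

model = torch.nn.Linear(16, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# GradScaler and autocast become no-ops when enabled=False, so one code path
# serves both plain FP32 and mixed-precision training.
use_amp = True
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn(8, 16, device="cuda")
y = torch.randint(0, 2, (8,), device="cuda")

with torch.cuda.amp.autocast(enabled=use_amp):   # forward pass and loss in mixed precision
    loss = torch.nn.functional.cross_entropy(model(x), y)

scaler.scale(loss).backward()   # scale the loss to avoid FP16 gradient underflow
scaler.step(optimizer)          # unscales gradients internally, then steps the optimizer
scaler.update()
optimizer.zero_grad()
```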
@@ -143,7 +143,7 @@ Computes Transformer-based similarity of predicted answer to gold labels to deri
Returns per QA pair a) the similarity of the most likely prediction (top 1) to all available gold labels
b) the highest similarity of all predictions to gold labels
c) a matrix consisting of the similarities of all the predicitions compared to all gold labels
c) a matrix consisting of the similarities of all the predictions compared to all gold labels

**Arguments**:
@@ -149,7 +149,7 @@ def train(data_dir: str,
evaluate_every: int = 300,
save_dir: Optional[str] = None,
num_processes: Optional[int] = None,
use_amp: str = None,
use_amp: bool = False,
checkpoint_root_dir: Path = Path("model_checkpoints"),
checkpoint_every: Optional[int] = None,
checkpoints_to_keep: int = 3,
@@ -193,14 +193,10 @@ Note that the evaluation report is logged at evaluation level INFO while Haystac
- `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing.
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
Set to None to use all CPU cores minus one.
- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
Available options:
None (Don't use AMP)
"O0" (Normal FP32 training)
"O1" (Mixed Precision => Recommended)
"O2" (Almost FP16)
"O3" (Pure FP16).
See details on: https://nvidia.github.io/apex/amp.html
- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
training speed and reduce GPU memory usage.
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
- `checkpoint_root_dir`: The Path of a directory where all train checkpoints are saved. For each individual
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
- `checkpoint_every`: Save a train checkpoint after this many steps of training.
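Going by the updated signature above, turning mixed precision on from user code is now a plain boolean flag. A hedged usage sketch (model name, paths and hyperparameters are placeholders, assuming the usual FARMReader.train arguments):

```python
# Sketch: fine-tuning a FARMReader with native AMP enabled (illustrative values).
from haystack.nodes import FARMReader

reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)
reader.train(
    data_dir="data/squad",        # directory containing a SQuAD-format training file
    train_filename="train.json",  # placeholder file name
    use_amp=True,                 # native torch.cuda.amp instead of an Apex "O" level
    n_epochs=1,
    batch_size=16,
)
```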
@ -237,7 +233,7 @@ def distil_prediction_layer_from(
|
||||
evaluate_every: int = 300,
|
||||
save_dir: Optional[str] = None,
|
||||
num_processes: Optional[int] = None,
|
||||
use_amp: str = None,
|
||||
use_amp: bool = False,
|
||||
checkpoint_root_dir: Path = Path("model_checkpoints"),
|
||||
checkpoint_every: Optional[int] = None,
|
||||
checkpoints_to_keep: int = 3,
|
||||
@ -284,7 +280,7 @@ A list containing torch device objects and/or strings is supported (For example
|
||||
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
|
||||
parameter is not used and a single cpu device is used for inference.
|
||||
- `student_batch_size`: Number of samples the student model receives in one batch for training
|
||||
- `student_batch_size`: Number of samples the teacher model receives in one batch for distillation
|
||||
- `teacher_batch_size`: Number of samples the teacher model receives in one batch for distillation
|
||||
- `n_epochs`: Number of iterations on the whole training data set
|
||||
- `learning_rate`: Learning rate of the optimizer
|
||||
- `max_seq_len`: Maximum text length (in tokens). Everything longer gets cut down.
|
||||
@ -296,14 +292,10 @@ Options for different schedules are available in FARM.
|
||||
- `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing.
|
||||
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
|
||||
Set to None to use all CPU cores minus one.
|
||||
- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
|
||||
Available options:
|
||||
None (Don't use AMP)
|
||||
"O0" (Normal FP32 training)
|
||||
"O1" (Mixed Precision => Recommended)
|
||||
"O2" (Almost FP16)
|
||||
"O3" (Pure FP16).
|
||||
See details on: https://nvidia.github.io/apex/amp.html
|
||||
- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
|
||||
training speed and reduce GPU memory usage.
|
||||
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
|
||||
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
|
||||
- `checkpoint_root_dir`: the Path of directory where all train checkpoints are saved. For each individual
|
||||
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
|
||||
- `checkpoint_every`: save a train checkpoint after this many steps of training.
|
||||
@ -347,7 +339,7 @@ def distil_intermediate_layers_from(
|
||||
evaluate_every: int = 300,
|
||||
save_dir: Optional[str] = None,
|
||||
num_processes: Optional[int] = None,
|
||||
use_amp: str = None,
|
||||
use_amp: bool = False,
|
||||
checkpoint_root_dir: Path = Path("model_checkpoints"),
|
||||
checkpoint_every: Optional[int] = None,
|
||||
checkpoints_to_keep: int = 3,
|
||||
@ -389,8 +381,7 @@ that gets split off from training data for eval.
|
||||
A list containing torch device objects and/or strings is supported (For example
|
||||
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
|
||||
parameter is not used and a single cpu device is used for inference.
|
||||
- `student_batch_size`: Number of samples the student model receives in one batch for training
|
||||
- `student_batch_size`: Number of samples the teacher model receives in one batch for distillation
|
||||
- `batch_size`: Number of samples the student model and teacher model receives in one batch for training
|
||||
- `n_epochs`: Number of iterations on the whole training data set
|
||||
- `learning_rate`: Learning rate of the optimizer
|
||||
- `max_seq_len`: Maximum text length (in tokens). Everything longer gets cut down.
|
||||
@ -402,21 +393,16 @@ Options for different schedules are available in FARM.
|
||||
- `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing.
|
||||
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
|
||||
Set to None to use all CPU cores minus one.
|
||||
- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
|
||||
Available options:
|
||||
None (Don't use AMP)
|
||||
"O0" (Normal FP32 training)
|
||||
"O1" (Mixed Precision => Recommended)
|
||||
"O2" (Almost FP16)
|
||||
"O3" (Pure FP16).
|
||||
See details on: https://nvidia.github.io/apex/amp.html
|
||||
- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
|
||||
training speed and reduce GPU memory usage.
|
||||
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
|
||||
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
|
||||
- `checkpoint_root_dir`: the Path of directory where all train checkpoints are saved. For each individual
|
||||
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
|
||||
- `checkpoint_every`: save a train checkpoint after this many steps of training.
|
||||
- `checkpoints_to_keep`: maximum number of train checkpoints to save.
|
||||
- `caching`: whether or not to use caching for preprocessed dataset and teacher logits
|
||||
- `cache_path`: Path to cache the preprocessed dataset and teacher logits
|
||||
- `distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
|
||||
- `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
|
||||
- `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
|
||||
- `processor`: The processor to use for preprocessing. If None, the default SquadProcessor is used.
|
||||
@ -663,7 +649,7 @@ Example:
|
||||
**Arguments**:
|
||||
|
||||
- `question`: Question string
|
||||
- `documents`: List of documents as string type
|
||||
- `texts`: A list of Document texts as a string type
|
||||
- `top_k`: The maximum number of answers to return
|
||||
|
||||
**Returns**:
|
||||
|
||||
@ -946,7 +946,7 @@ def train(data_dir: str,
|
||||
weight_decay: float = 0.0,
|
||||
num_warmup_steps: int = 100,
|
||||
grad_acc_steps: int = 1,
|
||||
use_amp: str = None,
|
||||
use_amp: bool = False,
|
||||
optimizer_name: str = "AdamW",
|
||||
optimizer_correct_bias: bool = True,
|
||||
save_dir: str = "../saved_models/dpr",
|
||||
@ -984,12 +984,10 @@ you should use the file_system strategy.
|
||||
- `epsilon`: epsilon parameter of optimizer
|
||||
- `weight_decay`: weight decay parameter of optimizer
|
||||
- `grad_acc_steps`: number of steps to accumulate gradient over before back-propagation is done
|
||||
- `use_amp`: Whether to use automatic mixed precision (AMP) or not. The options are:
|
||||
"O0" (FP32)
|
||||
"O1" (Mixed Precision)
|
||||
"O2" (Almost FP16)
|
||||
"O3" (Pure FP16).
|
||||
For more information, refer to: https://nvidia.github.io/apex/amp.html
|
||||
- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
|
||||
training speed and reduce GPU memory usage.
|
||||
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
|
||||
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
|
||||
- `optimizer_name`: what optimizer to use (default: AdamW)
|
||||
- `num_warmup_steps`: number of warmup steps
|
||||
- `optimizer_correct_bias`: Whether to correct bias in optimizer
|
||||
@ -1305,7 +1303,7 @@ def train(data_dir: str,
|
||||
weight_decay: float = 0.0,
|
||||
num_warmup_steps: int = 100,
|
||||
grad_acc_steps: int = 1,
|
||||
use_amp: str = None,
|
||||
use_amp: bool = False,
|
||||
optimizer_name: str = "AdamW",
|
||||
optimizer_correct_bias: bool = True,
|
||||
save_dir: str = "../saved_models/mm_retrieval",
|
||||
@ -1345,12 +1343,10 @@ very similar (high score by BM25) to query but do not contain the answer)-
|
||||
- `epsilon`: Epsilon parameter of optimizer.
|
||||
- `weight_decay`: Weight decay parameter of optimizer.
|
||||
- `grad_acc_steps`: Number of steps to accumulate gradient over before back-propagation is done.
|
||||
- `use_amp`: Whether to use automatic mixed precision (AMP) or not. The options are:
|
||||
"O0" (FP32)
|
||||
"O1" (Mixed Precision)
|
||||
"O2" (Almost FP16)
|
||||
"O3" (Pure FP16).
|
||||
For more information, refer to: https://nvidia.github.io/apex/amp.html
|
||||
- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
|
||||
training speed and reduce GPU memory usage.
|
||||
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
|
||||
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
|
||||
- `optimizer_name`: What optimizer to use (default: TransformersAdamW).
|
||||
- `num_warmup_steps`: Number of warmup steps.
|
||||
- `optimizer_correct_bias`: Whether to correct bias in optimizer.
|
||||
|
||||
@ -72,7 +72,7 @@ class Processor(ABC):
|
||||
:param dev_filename: The name of the file containing the dev data. If None and 0.0 < dev_split < 1.0 the dev set
|
||||
will be a slice of the train set.
|
||||
:param test_filename: The name of the file containing test data.
|
||||
:param dev_split: The proportion of the train set that will sliced. Only works if dev_filename is set to None
|
||||
:param dev_split: The proportion of the train set that will be sliced. Only works if `dev_filename` is set to `None`.
|
||||
:param data_dir: The directory in which the train, test and perhaps dev files can be found.
|
||||
:param tasks: Tasks for which the processor shall extract labels from the input data.
|
||||
Usually this includes a single, default task, e.g. text classification.
|
||||
@ -137,7 +137,7 @@ class Processor(ABC):
|
||||
If None and 0.0 < dev_split < 1.0 the dev set
|
||||
will be a slice of the train set.
|
||||
:param test_filename: The name of the file containing test data.
|
||||
:param dev_split: The proportion of the train set that will sliced.
|
||||
:param dev_split: The proportion of the train set that will be sliced.
|
||||
Only works if dev_filename is set to None
|
||||
:param kwargs: placeholder for passing generic parameters
|
||||
:return: An instance of the specified processor.
|
||||
@ -217,6 +217,7 @@ class Processor(ABC):
|
||||
tokenizer_class=None,
|
||||
tokenizer_args=None,
|
||||
use_fast=True,
|
||||
max_query_length=64,
|
||||
**kwargs,
|
||||
):
|
||||
tokenizer_args = tokenizer_args or {}
|
||||
@ -238,6 +239,7 @@ class Processor(ABC):
|
||||
metric="squad",
|
||||
data_dir="data",
|
||||
doc_stride=doc_stride,
|
||||
max_query_length=max_query_length,
|
||||
)
|
||||
elif task_type == "embeddings":
|
||||
processor = InferenceProcessor(tokenizer=tokenizer, max_seq_len=max_seq_len)
|
||||
@ -396,7 +398,7 @@ class SquadProcessor(Processor):
|
||||
:param dev_filename: The name of the file containing the dev data. If None and 0.0 < dev_split < 1.0 the dev set
|
||||
will be a slice of the train set.
|
||||
:param test_filename: None
|
||||
:param dev_split: The proportion of the train set that will sliced. Only works if dev_filename is set to None
|
||||
:param dev_split: The proportion of the train set that will be sliced. Only works if `dev_filename` is set to `None`.
|
||||
:param doc_stride: When the document containing the answer is too long it gets split into part, strided by doc_stride
|
||||
:param max_query_length: Maximum length of the question (in number of subword tokens)
|
||||
:param proxies: proxy configuration to allow downloads of remote datasets.
|
||||
|
||||
@ -130,6 +130,7 @@ class Inferencer:
|
||||
multithreading_rust: bool = True,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
devices: Optional[List[Union[str, torch.device]]] = None,
|
||||
max_query_length: int = 64,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -178,6 +179,7 @@ class Inferencer:
|
||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||
Additional information can be found here
|
||||
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
|
||||
:param max_query_length: Only QA: Maximum length of the question in number of tokens.
|
||||
:return: An instance of the Inferencer.
|
||||
"""
|
||||
if tokenizer_args is None:
|
||||
@ -228,6 +230,7 @@ class Inferencer:
|
||||
tokenizer_args=tokenizer_args,
|
||||
use_fast=use_fast,
|
||||
use_auth_token=use_auth_token,
|
||||
max_query_length=max_query_length,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -241,6 +244,8 @@ class Inferencer:
|
||||
"Please set a lower value for doc_stride (Suggestions: doc_stride=128, max_seq_len=384) "
|
||||
)
|
||||
processor.doc_stride = doc_stride
|
||||
if hasattr(processor, "max_query_length"):
|
||||
processor.max_query_length = max_query_length
|
||||
|
||||
return cls(
|
||||
model,
|
||||
|
||||
@ -280,7 +280,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
|
||||
* vocab.txt vocab file for language model, turning text to Wordpiece Tokens
|
||||
|
||||
:param load_dir: Location where the AdaptiveModel is stored.
|
||||
:param device: To which device we want to sent the model, either torch.device("cpu") or torch.device("cuda").
|
||||
:param device: Specifies the device to which you want to send the model, either torch.device("cpu") or torch.device("cuda").
|
||||
:param strict: Whether to strictly enforce that the keys loaded from saved model match the ones in
|
||||
the PredictionHead (see torch.nn.module.load_state_dict()).
|
||||
:param processor: Processor to populate prediction head with information coming from tasks.
|
||||
|
||||
@@ -14,29 +14,6 @@ from haystack.utils.experiment_tracking import Tracker as tracker
logger = logging.getLogger(__name__)

try:
from apex import amp # pylint: disable=import-error
logger.info("apex is available.")
try:
from apex.parallel import convert_syncbn_model # pylint: disable=import-error
APEX_PARALLEL_AVAILABLE = True
logger.info("apex.parallel is available.")
except AttributeError:
APEX_PARALLEL_AVAILABLE = False
logger.debug("apex.parallel not found, won't use it. See https://nvidia.github.io/apex/parallel.html")
AMP_AVAILABLE = True
except ImportError:
AMP_AVAILABLE = False
APEX_PARALLEL_AVAILABLE = False
logger.debug("apex not found, won't use it. See https://nvidia.github.io/apex/")

class WrappedDataParallel(DataParallel):
"""
|
||||
@ -57,7 +34,8 @@ class WrappedDataParallel(DataParallel):
|
||||
class WrappedDDP(DistributedDataParallel):
|
||||
"""
|
||||
A way of adapting attributes of underlying class to distributed mode. Same as in WrappedDataParallel above.
|
||||
Even when using distributed on a single machine with multiple GPUs, apex can speed up training significantly.
|
||||
Even when using distributed on a single computer with multiple GPUs, automatic mixed precision can speed up training
|
||||
significantly.
|
||||
Distributed code must be launched with "python -m torch.distributed.launch --nproc_per_node=1 run_script.py"
|
||||
"""
|
||||
|
||||
@ -79,7 +57,7 @@ def initialize_optimizer(
|
||||
distributed: bool = False,
|
||||
grad_acc_steps: int = 1,
|
||||
local_rank: int = -1,
|
||||
use_amp: Optional[str] = None,
|
||||
use_amp: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes an optimizer, a learning rate scheduler and converts the model if needed (e.g for mixed precision).
|
||||
@@ -91,15 +69,13 @@ def initialize_optimizer(
:param n_epochs: number of epochs for training
:param device: Which hardware will be used by the optimizer. Either torch.device("cpu") or torch.device("cuda").
:param learning_rate: Learning rate
:param optimizer_opts: Dict to customize the optimizer. Choose any optimizer available from torch.optim, apex.optimizers or
:param optimizer_opts: Dictionary to customize the optimizer. Choose any optimizer available from torch.optim or
transformers.optimization by supplying the class name and the parameters for the constructor.
Examples:
1) AdamW from Transformers (Default):
{"name": "AdamW", "correct_bias": False, "weight_decay": 0.01}
2) SGD from pytorch:
{"name": "SGD", "momentum": 0.0}
3) FusedLAMB from apex:
{"name": "FusedLAMB", "bias_correction": True}
:param schedule_opts: Dict to customize the learning rate schedule.
Choose any Schedule from Pytorch or Huggingface's Transformers by supplying the class name
and the parameters needed by the constructor.
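As a quick illustration of the two customization dicts described in this docstring (values are only examples; the `LinearWarmup` schedule name and `warmup_proportion` parameter appear elsewhere in this commit):

```python
# Illustrative optimizer/schedule configuration dicts for initialize_optimizer.
optimizer_opts = {"name": "AdamW", "correct_bias": False, "weight_decay": 0.01}
schedule_opts = {"name": "LinearWarmup", "warmup_proportion": 0.2}
```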
@@ -119,20 +95,17 @@ def initialize_optimizer(
:param distributed: Whether training on distributed machines
:param grad_acc_steps: Number of steps to accumulate gradients for. Helpful to mimic large batch_sizes on small machines.
:param local_rank: rank of the machine in a distributed setting
:param use_amp: Optimization level of nvidia's automatic mixed precision (AMP). The higher the level, the faster the model.
Options:
"O0" (Normal FP32 training)
"O1" (Mixed Precision => Recommended)
"O2" (Almost FP16)
"O3" (Pure FP16).
See details on: https://nvidia.github.io/apex/amp.html
:param use_amp: This option is deprecated. Haystack supports only PyTorch automatic mixed precision (AMP). We no longer support the Apex library. This means this function doesn't use `use_amp` any longer
because it's not needed to initialize native Pytorch AMP. If you provide a value, you'll see a warning message.

:return: model, optimizer, scheduler
"""
if use_amp and not AMP_AVAILABLE:
raise ImportError(
f"Got use_amp = {use_amp}, but cannot find apex. "
"Please install Apex if you want to make use of automatic mixed precision. "
"https://github.com/NVIDIA/apex"
if isinstance(use_amp, str):
logger.warning(
"Haystack supports only PyTorch automatic mixed precision. We no longer support the Apex library.\n"
"This means that modeling.model.initialize_optimizer no longer uses use_amp since it is not needed\n"
"to initialize native PyTorch automatic mixed precision. For more information, see [Optimization](https://haystack.deepset.ai/guides/optimization).\n"
"In the future provide use_amp=True to use automatic mixed precision."
)

if (schedule_opts is not None) and (not isinstance(schedule_opts, dict)):
@@ -161,13 +134,13 @@ def initialize_optimizer(
schedule_opts["num_training_steps"] = num_train_optimization_steps

# Log params
tracker.track_params({"use_amp": use_amp, "num_train_optimization_steps": schedule_opts["num_training_steps"]})
tracker.track_params({"num_train_optimization_steps": schedule_opts["num_training_steps"]})

# Get optimizer from pytorch, transformers or apex
# Get optimizer from pytorch or transformers
optimizer = _get_optim(model, optimizer_opts)

# Adjust for parallel training + amp
model, optimizer = optimize_model(model, device, local_rank, optimizer, distributed, use_amp)
# Adjust for parallel training
model, optimizer = optimize_model(model, device, local_rank, optimizer, distributed)

# Get learning rate schedule - moved below to supress warning
scheduler = get_scheduler(optimizer, schedule_opts)
@@ -219,7 +192,7 @@ def _get_optim(model, opts: Dict):
if weight_decay is not None:
optimizable_parameters[0]["weight_decay"] = weight_decay # type: ignore

# Import optimizer by checking in order: torch, transformers, apex and local imports
# Import optimizer by checking in order: torch, transformers and local imports
try:
optim_constructor = getattr(import_module("torch.optim"), optimizer_name)
except AttributeError:
@@ -227,17 +200,14 @@ def _get_optim(model, opts: Dict):
optim_constructor = getattr(import_module("transformers.optimization"), optimizer_name)
except AttributeError:
try:
optim_constructor = getattr(import_module("apex.optimizers"), optimizer_name)
# Workaround to allow loading AdamW from transformers
# pytorch > 1.2 has now also a AdamW (but without the option to set bias_correction = False,
# which is done in the original BERT implementation)
optim_constructor = getattr(sys.modules[__name__], optimizer_name)
except (AttributeError, ImportError):
try:
# Workaround to allow loading AdamW from transformers
# pytorch > 1.2 has now also a AdamW (but without the option to set bias_correction = False,
# which is done in the original BERT implementation)
optim_constructor = getattr(sys.modules[__name__], optimizer_name)
except (AttributeError, ImportError):
raise AttributeError(
f"Optimizer '{optimizer_name}' not found in 'torch', 'transformers', 'apex' or 'local imports"
)
raise AttributeError(
f"We couldn't find optimizer '{optimizer_name}' in 'torch', 'transformers' or 'local imports'."
)

return optim_constructor(optimizable_parameters)
@@ -298,9 +268,9 @@ def optimize_model(
model: "AdaptiveModel",
device: torch.device,
local_rank: int,
optimizer=None,
distributed: Optional[bool] = False,
use_amp: Optional[str] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
distributed: bool = False,
use_amp: bool = False,
):
"""
Wraps MultiGPU or distributed usage around a model
@@ -310,24 +280,23 @@ def optimize_model(
:param device: either torch.device("cpu") or torch.device("cuda"). Get the device from `initialize_device_settings()`
:param distributed: Whether training on distributed machines
:param local_rank: rank of the machine in a distributed setting
:param use_amp: Optimization level of nvidia's automatic mixed precision (AMP). The higher the level, the faster the model.
Options:
"O0" (Normal FP32 training)
"O1" (Mixed Precision => Recommended)
"O2" (Almost FP16)
"O3" (Pure FP16).
See details on: https://nvidia.github.io/apex/amp.html
:param optimizer: torch optimizer
:param use_amp: This option is deprecated. Haystack supports only PyTorch automatic mixed precision (AMP). We no longer support the Apex library. This means this function no longer uses `use_amp`
because it's not needed to initialize native Pytorch AMP. If you provide a value, you'll see a warning message.

:return: model, optimizer
"""
model, optimizer = _init_amp(model, device, optimizer, use_amp)
if isinstance(use_amp, str):
logger.warning(
"Haystack supports only PyTorch automatic mixed precision. We no longer support the Apex library.\n"
"This means that modeling.model.initialize_optimizer no longer uses use_amp since it's not needed\n"
"to initialize native PyTorch automatic mixed precision. For more information, see [Optimization](https://haystack.deepset.ai/guides/optimization).\n"
"In the future, set `use_amp=True` to use automatic mixed precision."
)

model = model.to(device)

if distributed:
if APEX_PARALLEL_AVAILABLE:
model = convert_syncbn_model(model)
logger.info("Multi-GPU Training via DistributedDataParallel and apex.parallel")
else:
logger.info("Multi-GPU Training via DistributedDataParallel")

# for some models DistributedDataParallel might complain about parameters
# not contributing to loss. find_used_parameters remedies that.
model = WrappedDDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
@@ -337,16 +306,3 @@ def optimize_model(
logger.info("Multi-GPU Training via DataParallel")

return model, optimizer

def _init_amp(model, device, optimizer=None, use_amp=None):
model = model.to(device)
if use_amp and optimizer:
if AMP_AVAILABLE:
model, optimizer = amp.initialize(model, optimizer, opt_level=use_amp)
else:
logger.warning(
f"Can't find AMP although you specificed to use amp with level {use_amp}. Will continue without AMP ..."
)

return model, optimizer
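The removed `_init_amp` helper is what used to wrap the model and optimizer through Apex. With native AMP there is no equivalent initialization step; a short sketch of the difference (illustrative only, not part of the diff):

```python
import torch

# With Apex (removed above), model and optimizer had to be rewrapped once up front, e.g.
#     model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
# With native AMP, model and optimizer are left untouched; mixed precision is applied
# per training step with autocast() around the forward pass and a GradScaler around
# backward()/step() (see the Trainer changes further down in this commit).
scaler = torch.cuda.amp.GradScaler(enabled=True)  # created once, reused at every step
```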
@@ -297,10 +297,10 @@ class QAPred(Pred):

def _answers_to_json(self, ext_id, squad=False) -> List[Dict]:
"""
Convert all answers into a json format
Convert all answers into a json format.

:param id: ID of the question document pair
:param squad: If True, no_answers are represented by the empty string instead of "no_answer"
:param ext_id: ID of the question document pair.
:param squad: If True, no_answers are represented by the empty string instead of "no_answer".
"""
ret = []
@@ -23,13 +23,6 @@ from haystack.modeling.utils import GracefulKiller
from haystack.utils.experiment_tracking import Tracker as tracker
from haystack.utils.early_stopping import EarlyStopping

try:
from apex import amp

AMP_AVAILABLE = True
except ImportError:
AMP_AVAILABLE = False

logger = logging.getLogger(__name__)

@@ -51,7 +44,7 @@ class Trainer:
lr_schedule=None,
evaluate_every: int = 100,
eval_report: bool = True,
use_amp: Optional[str] = None,
use_amp: bool = False,
grad_acc_steps: int = 1,
local_rank: int = -1,
early_stopping: Optional[EarlyStopping] = None,
@@ -77,8 +70,10 @@ class Trainer:
:param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
:param evaluate_every: Perform dev set evaluation after this many steps of training.
:param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating
:param use_amp: Whether to use automatic mixed precision with Apex. One of the optimization levels must be chosen.
"O1" is recommended in almost all cases.
:param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
training speed and reduce GPU memory usage.
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
:param grad_acc_steps: Number of training steps for which the gradients should be accumulated.
Useful to achieve larger effective batch sizes that would not fit in GPU memory.
:param local_rank: Local rank of process when distributed training via DDP is used.
@@ -88,10 +83,10 @@ class Trainer:
:param checkpoint_on_sigterm: save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where
a SIGTERM notifies to save the training state and subsequently the instance is terminated.
:param checkpoint_every: save a train checkpoint after this many steps of training.
:param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual
:param checkpoint_every: Save a training checkpoint after this many steps of training.
:param checkpoint_root_dir: The directory Path where all training checkpoints are saved. For each individual
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
:param checkpoints_to_keep: maximum number of train checkpoints to save.
:param checkpoints_to_keep: The maximum number of training checkpoints to save.
:param from_epoch: the epoch number to start the training from. In the case when training resumes from a saved
checkpoint, it is used to fast-forward training to the last epoch in the checkpoint.
:param from_step: the step number to start the training from. In the case when training resumes from a saved
@@ -101,16 +96,30 @@ class Trainer:
:param disable_tqdm: Disable tqdm progress bar (helps to reduce verbosity in some environments)
:param max_grad_norm: Max gradient norm for clipping, default 1.0, set to None to disable
"""
amp_mapping = {"O0": False, "O1": True, "O2": True, "O3": True}
self.model = model
self.data_silo = data_silo
self.epochs = int(epochs)
if isinstance(use_amp, str):
if use_amp in amp_mapping:
logger.warning(
"The Trainer only supports native PyTorch automatic mixed precision and no longer supports the Apex library.\n"
f"Because you provided Apex optimization level {use_amp}, automatic mixed precision was set to {amp_mapping[use_amp]}.\n"
"In the future, set `use_amp=True` to turn on automatic mixed precision."
)
use_amp = amp_mapping[use_amp]
else:
raise Exception(
f"use_amp value {use_amp} is not supported. Please provide use_amp=True to turn on automatic mixed precision."
)
self.use_amp = use_amp
self.optimizer = optimizer
self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
self.evaluate_every = evaluate_every
self.eval_report = eval_report
self.evaluator_test = evaluator_test
self.n_gpu = n_gpu
self.grad_acc_steps = grad_acc_steps
self.use_amp = use_amp
self.lr_schedule = lr_schedule
self.device = device
self.local_rank = local_rank
@@ -122,12 +131,6 @@ class Trainer:
self.max_grad_norm = max_grad_norm
self.test_result = None

if use_amp and not AMP_AVAILABLE:
raise ImportError(
f"Got use_amp = {use_amp}, but cannot find apex. "
"Please install Apex if you want to make use of automatic mixed precision. "
"https://github.com/NVIDIA/apex"
)
self.checkpoint_on_sigterm = checkpoint_on_sigterm
if checkpoint_on_sigterm:
self.sigterm_handler = GracefulKiller() # type: Optional[GracefulKiller]
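The backward-compatibility shim above boils down to a small lookup; restated on its own (illustrative):

```python
# Old Apex opt-level strings are mapped onto the new boolean flag; any other string
# raises, and a deprecation warning is logged whenever the mapping is applied.
amp_mapping = {"O0": False, "O1": True, "O2": True, "O3": True}

legacy_value = "O1"                  # what callers used to pass for Apex
use_amp = amp_mapping[legacy_value]  # -> True
```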
@@ -203,7 +206,7 @@ class Trainer:

# Only for distributed training: we need to ensure that all ranks still have a batch left for training
if self.local_rank != -1:
if not self._all_ranks_have_data(has_data=1, step=step):
if not self._all_ranks_have_data(has_data=True, step=step):
early_break = True
break
@@ -297,47 +300,44 @@ class Trainer:
else:
module = self.model

if isinstance(module, AdaptiveModel):
logits = self.model.forward(
input_ids=batch["input_ids"], segment_ids=None, padding_mask=batch["padding_mask"]
)
with torch.cuda.amp.autocast(enabled=self.use_amp):
if isinstance(module, AdaptiveModel):
logits = self.model.forward(
input_ids=batch["input_ids"], segment_ids=None, padding_mask=batch["padding_mask"]
)

elif isinstance(module, BiAdaptiveModel):
logits = self.model.forward(
query_input_ids=batch["query_input_ids"],
query_segment_ids=batch["query_segment_ids"],
query_attention_mask=batch["query_attention_mask"],
passage_input_ids=batch["passage_input_ids"],
passage_segment_ids=batch["passage_segment_ids"],
passage_attention_mask=batch["passage_attention_mask"],
)
elif isinstance(module, BiAdaptiveModel):
logits = self.model.forward(
query_input_ids=batch["query_input_ids"],
query_segment_ids=batch["query_segment_ids"],
query_attention_mask=batch["query_attention_mask"],
passage_input_ids=batch["passage_input_ids"],
passage_segment_ids=batch["passage_segment_ids"],
passage_attention_mask=batch["passage_attention_mask"],
)

else:
logits = self.model.forward(**batch)
else:
logits = self.model.forward(**batch)

per_sample_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
return self.backward_propagate(per_sample_loss, step)
per_sample_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
loss = self.adjust_loss(per_sample_loss)
return self.backward_propagate(loss, step)

def backward_propagate(self, loss: torch.Tensor, step: int):
loss = self.adjust_loss(loss)
if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0]:
if self.local_rank in [-1, 0]:
tracker.track_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
if self.log_learning_rate:
tracker.track_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)
if self.use_amp:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()

self.scaler.scale(loss).backward()

if step % self.grad_acc_steps == 0:
if self.max_grad_norm is not None:
if self.use_amp:
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.lr_schedule:
self.lr_schedule.step()
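backward_propagate now routes everything through the GradScaler. The same pattern in isolation, with gradient accumulation and clipping included (toy model and data, CUDA assumed; illustrative only, not the Trainer's exact code):

```python
# Sketch of the scale -> accumulate -> unscale -> clip -> step -> update sequence.
import torch

model = torch.nn.Linear(16, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=True)
grad_acc_steps, max_grad_norm = 2, 1.0

for step in range(1, 9):
    x = torch.randn(8, 16, device="cuda")
    y = torch.randint(0, 2, (8,), device="cuda")
    with torch.cuda.amp.autocast(enabled=True):
        # divide by grad_acc_steps so accumulated gradients match a larger batch
        loss = torch.nn.functional.cross_entropy(model(x), y) / grad_acc_steps
    scaler.scale(loss).backward()           # gradients accumulate across steps
    if step % grad_acc_steps == 0:
        scaler.unscale_(optimizer)          # unscale before clipping in true scale
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)              # skips the step if inf/NaN gradients were found
        scaler.update()
        optimizer.zero_grad()
```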
@@ -350,7 +350,7 @@ class Trainer:
return loss

def log_params(self):
params = {"epochs": self.epochs, "n_gpu": self.n_gpu, "device": self.device}
params = {"epochs": self.epochs, "n_gpu": self.n_gpu, "device": self.device, "use_amp": self.use_amp}
tracker.track_params(params)

@classmethod
@@ -545,6 +545,7 @@ class Trainer:
"log_learning_rate": self.log_learning_rate,
"log_loss_every": self.log_loss_every,
"disable_tqdm": self.disable_tqdm,
"use_amp": self.use_amp,
}

return state_dict
@ -580,7 +581,7 @@ class Trainer:
|
||||
class DistillationTrainer(Trainer):
|
||||
"""
|
||||
This trainer uses the teacher logits from DistillationDataSilo
|
||||
to compute a distillation loss in addtion to the loss based on the labels.
|
||||
to compute a distillation loss in addition to the loss based on the labels.
|
||||
|
||||
**Example**
|
||||
```python
|
||||
@ -608,7 +609,7 @@ class DistillationTrainer(Trainer):
|
||||
lr_schedule: Optional[_LRScheduler] = None,
|
||||
evaluate_every: int = 100,
|
||||
eval_report: bool = True,
|
||||
use_amp: Optional[str] = None,
|
||||
use_amp: bool = False,
|
||||
grad_acc_steps: int = 1,
|
||||
local_rank: int = -1,
|
||||
early_stopping: Optional[EarlyStopping] = None,
|
||||
@ -631,7 +632,6 @@ class DistillationTrainer(Trainer):
|
||||
"""
|
||||
:param optimizer: An optimizer object that determines the learning strategy to be used during training
|
||||
:param model: The model to be trained
|
||||
:param teacher_model: The teacher model used for distillation
|
||||
:param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
|
||||
:param epochs: How many times the training procedure will loop through the train dataset
|
||||
:param n_gpu: The number of gpus available for training and evaluation.
|
||||
@ -639,8 +639,10 @@ class DistillationTrainer(Trainer):
|
||||
:param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
|
||||
:param evaluate_every: Perform dev set evaluation after this many steps of training.
|
||||
:param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating
|
||||
:param use_amp: Whether to use automatic mixed precision with Apex. One of the optimization levels must be chosen.
|
||||
"O1" is recommended in almost all cases.
|
||||
:param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
|
||||
training speed and reduce GPU memory usage.
|
||||
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
|
||||
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
|
||||
:param grad_acc_steps: Number of training steps for which the gradients should be accumulated.
|
||||
Useful to achieve larger effective batch sizes that would not fit in GPU memory.
|
||||
:param local_rank: Local rank of process when distributed training via DDP is used.
|
||||
@@ -709,23 +711,23 @@ class DistillationTrainer(Trainer):
keys = list(batch.keys())
keys = [key for key in keys if key.startswith("teacher_output")]
teacher_logits = [batch.pop(key) for key in keys]

logits = self.model.forward(
input_ids=batch.get("input_ids"),
segment_ids=batch.get("segment_ids"),
padding_mask=batch.get("padding_mask"),
output_hidden_states=batch.get("output_hidden_states"),
output_attentions=batch.get("output_attentions"),
)

student_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
distillation_loss = self.distillation_loss_fn(
student_logits=logits[0] / self.temperature, teacher_logits=teacher_logits[0] / self.temperature
)
combined_loss = distillation_loss * self.distillation_loss_weight * (self.temperature**2) + student_loss * (
1 - self.distillation_loss_weight
)
return self.backward_propagate(combined_loss, step)
with torch.cuda.amp.autocast(enabled=self.use_amp):
logits = self.model.forward(
input_ids=batch.get("input_ids"),
segment_ids=batch.get("segment_ids"),
padding_mask=batch.get("padding_mask"),
output_hidden_states=batch.get("output_hidden_states"),
output_attentions=batch.get("output_attentions"),
)
student_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
distillation_loss = self.distillation_loss_fn(
student_logits=logits[0] / self.temperature, teacher_logits=teacher_logits[0] / self.temperature
)
combined_loss = distillation_loss * self.distillation_loss_weight * (
self.temperature**2
) + student_loss * (1 - self.distillation_loss_weight)
loss = self.adjust_loss(combined_loss)
return self.backward_propagate(loss, step)
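For readers unfamiliar with the temperature terms being combined above, here is a minimal sketch of a temperature-scaled soft-target distillation loss, assuming the "kl_div" variant of `distillation_loss` (illustrative only, not the exact Haystack implementation):

```python
# Minimal sketch of temperature-scaled knowledge distillation (Hinton et al., 2015).
import torch
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, temperature=2.0):
    # Softening both distributions with T and rescaling by T**2 keeps gradient
    # magnitudes comparable to the hard-label loss.
    return F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)

student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
student_loss = torch.tensor(1.3)   # stand-in for the label-based loss
w = 0.75                           # distillation_loss_weight
combined = soft_target_loss(student_logits, teacher_logits) * w + student_loss * (1 - w)
```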
class TinyBERTDistillationTrainer(Trainer):
|
||||
@ -761,7 +763,7 @@ class TinyBERTDistillationTrainer(Trainer):
|
||||
lr_schedule: Optional[_LRScheduler] = None,
|
||||
evaluate_every: int = 100,
|
||||
eval_report: bool = True,
|
||||
use_amp: Optional[str] = None,
|
||||
use_amp: bool = False,
|
||||
grad_acc_steps: int = 1,
|
||||
local_rank: int = -1,
|
||||
early_stopping: Optional[EarlyStopping] = None,
|
||||
@ -789,8 +791,10 @@ class TinyBERTDistillationTrainer(Trainer):
|
||||
:param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
|
||||
:param evaluate_every: Perform dev set evaluation after this many steps of training.
|
||||
:param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating
|
||||
:param use_amp: Whether to use automatic mixed precision with Apex. One of the optimization levels must be chosen.
|
||||
"O1" is recommended in almost all cases.
|
||||
:param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
|
||||
training speed and reduce GPU memory usage.
|
||||
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
|
||||
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
|
||||
:param grad_acc_steps: Number of training steps for which the gradients should be accumulated.
|
||||
Useful to achieve larger effective batch sizes that would not fit in GPU memory.
|
||||
:param local_rank: Local rank of process when distributed training via DDP is used.
|
||||
@@ -849,16 +853,16 @@ class TinyBERTDistillationTrainer(Trainer):
self.loss = DataParallel(self.loss).to(device)

def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
return self.backward_propagate(
torch.sum(
with torch.cuda.amp.autocast(enabled=self.use_amp):
loss = torch.sum(
self.loss(
input_ids=batch.get("input_ids"),
segment_ids=batch.get("segment_ids"),
padding_mask=batch.get("padding_mask"),
)
),
step,
)
)
loss = self.adjust_loss(loss)
return self.backward_propagate(loss, step)

class DistillationLoss(Module):
@@ -62,7 +62,7 @@ def set_all_seeds(seed: int, deterministic_cudnn: bool = False) -> None:
but might slow down your training (see https://pytorch.org/docs/stable/notes/randomness.html#cudnn) !

:param seed:number to use as seed
:param deterministic_torch: Enable for full reproducibility when using CUDA. Caution: might slow down training.
:param deterministic_cudnn: Enable for full reproducibility when using CUDA. Caution: might slow down training.
"""
random.seed(seed)
np.random.seed(seed)
@ -407,7 +407,7 @@ def semantic_answer_similarity(
|
||||
Computes Transformer-based similarity of predicted answer to gold labels to derive a more meaningful metric than EM or F1.
|
||||
Returns per QA pair a) the similarity of the most likely prediction (top 1) to all available gold labels
|
||||
b) the highest similarity of all predictions to gold labels
|
||||
c) a matrix consisting of the similarities of all the predicitions compared to all gold labels
|
||||
c) a matrix consisting of the similarities of all the predictions compared to all gold labels
|
||||
|
||||
:param predictions: Predicted answers as list of multiple preds per question
|
||||
:param gold_labels: Labels as list of multiple possible answers per question
|
||||
|
||||
@ -68,6 +68,7 @@ class FARMReader(BaseReader):
|
||||
local_files_only=False,
|
||||
force_download=False,
|
||||
use_auth_token: Optional[Union[str, bool]] = None,
|
||||
max_query_length: int = 64,
|
||||
):
|
||||
|
||||
"""
|
||||
@ -128,6 +129,7 @@ class FARMReader(BaseReader):
|
||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||
Additional information can be found here
|
||||
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
|
||||
:param max_query_length: Maximum length of the question in number of tokens.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -151,6 +153,7 @@ class FARMReader(BaseReader):
|
||||
force_download=force_download,
|
||||
devices=self.devices,
|
||||
use_auth_token=use_auth_token,
|
||||
max_query_length=max_query_length,
|
||||
)
|
||||
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
|
||||
self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
|
||||
@ -159,6 +162,8 @@ class FARMReader(BaseReader):
|
||||
self.inferencer.model.prediction_heads[0].duplicate_filtering = duplicate_filtering
|
||||
self.inferencer.model.prediction_heads[0].use_confidence_scores_for_ranking = use_confidence_scores
|
||||
self.max_seq_len = max_seq_len
|
||||
self.doc_stride = doc_stride
|
||||
self.max_query_length = max_query_length
|
||||
self.progress_bar = progress_bar
|
||||
self.use_confidence_scores = use_confidence_scores
|
||||
self.confidence_threshold = confidence_threshold
|
||||
@ -181,7 +186,7 @@ class FARMReader(BaseReader):
|
||||
evaluate_every: int = 300,
|
||||
save_dir: Optional[str] = None,
|
||||
num_processes: Optional[int] = None,
|
||||
use_amp: Optional[str] = None,
|
||||
use_amp: bool = False,
|
||||
checkpoint_root_dir: Path = Path("model_checkpoints"),
|
||||
checkpoint_every: Optional[int] = None,
|
||||
checkpoints_to_keep: int = 3,
|
||||
@ -196,6 +201,9 @@ class FARMReader(BaseReader):
|
||||
processor: Optional[Processor] = None,
|
||||
grad_acc_steps: int = 1,
|
||||
early_stopping: Optional[EarlyStopping] = None,
|
||||
distributed: bool = False,
|
||||
doc_stride: Optional[int] = None,
|
||||
max_query_length: Optional[int] = None,
|
||||
):
|
||||
if dev_filename:
|
||||
dev_split = 0
|
||||
@ -211,6 +219,10 @@ class FARMReader(BaseReader):
|
||||
devices = self.devices
|
||||
if max_seq_len is None:
|
||||
max_seq_len = self.max_seq_len
|
||||
if doc_stride is None:
|
||||
doc_stride = self.doc_stride
|
||||
if max_query_length is None:
|
||||
max_query_length = self.max_query_length
|
||||
|
||||
devices, n_gpu = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
|
||||
|
||||
@ -226,6 +238,8 @@ class FARMReader(BaseReader):
|
||||
processor = SquadProcessor(
|
||||
tokenizer=self.inferencer.processor.tokenizer,
|
||||
max_seq_len=max_seq_len,
|
||||
max_query_length=max_query_length,
|
||||
doc_stride=doc_stride,
|
||||
label_list=label_list,
|
||||
metric=metric,
|
||||
train_filename=train_filename,
|
||||
@ -247,7 +261,7 @@ class FARMReader(BaseReader):
|
||||
device=devices[0],
|
||||
processor=processor,
|
||||
batch_size=batch_size,
|
||||
distributed=False,
|
||||
distributed=distributed,
|
||||
max_processes=num_processes,
|
||||
caching=caching,
|
||||
cache_path=cache_path,
|
||||
@ -256,7 +270,7 @@ class FARMReader(BaseReader):
|
||||
data_silo = DataSilo(
|
||||
processor=processor,
|
||||
batch_size=batch_size,
|
||||
distributed=False,
|
||||
distributed=distributed,
|
||||
max_processes=num_processes,
|
||||
caching=caching,
|
||||
cache_path=cache_path,
|
||||
@ -265,14 +279,13 @@ class FARMReader(BaseReader):
|
||||
# 3. Create an optimizer and pass the already initialized model
|
||||
model, optimizer, lr_schedule = initialize_optimizer(
|
||||
model=self.inferencer.model,
|
||||
# model=self.inferencer.model,
|
||||
learning_rate=learning_rate,
|
||||
schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
|
||||
n_batches=len(data_silo.loaders["train"]),
|
||||
n_epochs=n_epochs,
|
||||
device=devices[0],
|
||||
use_amp=use_amp,
|
||||
grad_acc_steps=grad_acc_steps,
|
||||
distributed=distributed,
|
||||
)
|
||||
# 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
|
||||
if tinybert:
|
||||
@ -360,7 +373,7 @@ class FARMReader(BaseReader):
|
||||
evaluate_every: int = 300,
|
||||
save_dir: Optional[str] = None,
|
||||
num_processes: Optional[int] = None,
|
||||
use_amp: Optional[str] = None,
|
||||
use_amp: bool = False,
|
||||
checkpoint_root_dir: Path = Path("model_checkpoints"),
|
||||
checkpoint_every: Optional[int] = None,
|
||||
checkpoints_to_keep: int = 3,
|
||||
@ -368,6 +381,7 @@ class FARMReader(BaseReader):
|
||||
cache_path: Path = Path("cache/data_silo"),
|
||||
grad_acc_steps: int = 1,
|
||||
early_stopping: Optional[EarlyStopping] = None,
|
||||
max_query_length: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Fine-tune a model on a QA dataset. Options:
|
||||
@ -404,14 +418,10 @@ class FARMReader(BaseReader):
|
||||
:param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing.
|
||||
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
|
||||
Set to None to use all CPU cores minus one.
|
||||
:param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
|
||||
Available options:
|
||||
None (Don't use AMP)
|
||||
"O0" (Normal FP32 training)
|
||||
"O1" (Mixed Precision => Recommended)
|
||||
"O2" (Almost FP16)
|
||||
"O3" (Pure FP16).
|
||||
See details on: https://nvidia.github.io/apex/amp.html
|
||||
:param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
|
||||
training speed and reduce GPU memory usage.
|
||||
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
|
||||
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
|
||||
:param checkpoint_root_dir: The Path of a directory where all train checkpoints are saved. For each individual
|
||||
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
|
||||
:param checkpoint_every: Save a train checkpoint after this many steps of training.
|
||||
@ -420,6 +430,7 @@ class FARMReader(BaseReader):
|
||||
:param cache_path: The Path to cache the preprocessed dataset.
|
||||
:param grad_acc_steps: The number of steps to accumulate gradients for before performing a backward pass.
|
||||
:param early_stopping: An initialized EarlyStopping object to control early stopping and saving of the best models.
|
||||
:param max_query_length: Maximum length of the question in number of tokens.
|
||||
:return: None
|
||||
"""
|
||||
return self._training_procedure(
|
||||
@ -446,6 +457,8 @@ class FARMReader(BaseReader):
|
||||
cache_path=cache_path,
|
||||
grad_acc_steps=grad_acc_steps,
|
||||
early_stopping=early_stopping,
|
||||
max_query_length=max_query_length,
|
||||
distributed=False,
|
||||
)
|
||||
|
||||
def distil_prediction_layer_from(
|
||||
@ -467,7 +480,7 @@ class FARMReader(BaseReader):
|
||||
evaluate_every: int = 300,
|
||||
save_dir: Optional[str] = None,
|
||||
num_processes: Optional[int] = None,
|
||||
use_amp: Optional[str] = None,
|
||||
use_amp: bool = False,
|
||||
checkpoint_root_dir: Path = Path("model_checkpoints"),
|
||||
checkpoint_every: Optional[int] = None,
|
||||
checkpoints_to_keep: int = 3,
|
||||
@ -510,7 +523,7 @@ class FARMReader(BaseReader):
|
||||
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
|
||||
parameter is not used and a single cpu device is used for inference.
|
||||
:param student_batch_size: Number of samples the student model receives in one batch for training
|
||||
:param student_batch_size: Number of samples the teacher model receives in one batch for distillation
|
||||
:param teacher_batch_size: Number of samples the teacher model receives in one batch for distillation
|
||||
:param n_epochs: Number of iterations on the whole training data set
|
||||
:param learning_rate: Learning rate of the optimizer
|
||||
:param max_seq_len: Maximum text length (in tokens). Everything longer gets cut down.
|
||||
@ -522,14 +535,10 @@ class FARMReader(BaseReader):
|
||||
:param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing.
|
||||
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
|
||||
Set to None to use all CPU cores minus one.
|
||||
:param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
|
||||
Available options:
|
||||
None (Don't use AMP)
|
||||
"O0" (Normal FP32 training)
|
||||
"O1" (Mixed Precision => Recommended)
|
||||
"O2" (Almost FP16)
|
||||
"O3" (Pure FP16).
|
||||
See details on: https://nvidia.github.io/apex/amp.html
|
||||
:param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
|
||||
training speed and reduce GPU memory usage.
|
||||
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
|
||||
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
|
||||
:param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual
|
||||
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
|
||||
:param checkpoint_every: save a train checkpoint after this many steps of training.
|
||||
@ -577,6 +586,7 @@ class FARMReader(BaseReader):
|
||||
temperature=temperature,
|
||||
grad_acc_steps=grad_acc_steps,
|
||||
early_stopping=early_stopping,
|
||||
distributed=False,
|
||||
)

def distil_intermediate_layers_from(
@ -597,7 +607,7 @@ class FARMReader(BaseReader):
evaluate_every: int = 300,
save_dir: Optional[str] = None,
num_processes: Optional[int] = None,
use_amp: Optional[str] = None,
use_amp: bool = False,
checkpoint_root_dir: Path = Path("model_checkpoints"),
checkpoint_every: Optional[int] = None,
checkpoints_to_keep: int = 3,
@ -635,8 +645,7 @@ class FARMReader(BaseReader):
A list containing torch device objects and/or strings is supported (For example
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
parameter is not used and a single cpu device is used for inference.
:param student_batch_size: Number of samples the student model receives in one batch for training
:param student_batch_size: Number of samples the teacher model receives in one batch for distillation
:param batch_size: Number of samples the student model and teacher model receives in one batch for training
:param n_epochs: Number of iterations on the whole training data set
:param learning_rate: Learning rate of the optimizer
:param max_seq_len: Maximum text length (in tokens). Everything longer gets cut down.
@ -648,21 +657,16 @@ class FARMReader(BaseReader):
:param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing.
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
Set to None to use all CPU cores minus one.
:param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
Available options:
None (Don't use AMP)
"O0" (Normal FP32 training)
"O1" (Mixed Precision => Recommended)
"O2" (Almost FP16)
"O3" (Pure FP16).
See details on: https://nvidia.github.io/apex/amp.html
:param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
training speed and reduce GPU memory usage.
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
:param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
:param checkpoint_every: save a train checkpoint after this many steps of training.
:param checkpoints_to_keep: maximum number of train checkpoints to save.
:param caching: whether or not to use caching for preprocessed dataset and teacher logits
:param cache_path: Path to cache the preprocessed dataset and teacher logits
:param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
:param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
:param processor: The processor to use for preprocessing. If None, the default SquadProcessor is used.
@ -700,6 +704,7 @@ class FARMReader(BaseReader):
processor=processor,
grad_acc_steps=grad_acc_steps,
early_stopping=early_stopping,
distributed=False,
)

def update_parameters(
@ -1297,7 +1302,7 @@ class FARMReader(BaseReader):
```

:param question: Question string
:param documents: List of documents as string type
:param texts: A list of Document texts as a string type
:param top_k: The maximum number of answers to return
:return: Dict containing question and answers
"""

@ -593,7 +593,7 @@ class DensePassageRetriever(DenseRetriever):
weight_decay: float = 0.0,
num_warmup_steps: int = 100,
grad_acc_steps: int = 1,
use_amp: Optional[str] = None,
use_amp: bool = False,
optimizer_name: str = "AdamW",
optimizer_correct_bias: bool = True,
save_dir: str = "../saved_models/dpr",
@ -628,12 +628,10 @@ class DensePassageRetriever(DenseRetriever):
:param epsilon: epsilon parameter of optimizer
:param weight_decay: weight decay parameter of optimizer
:param grad_acc_steps: number of steps to accumulate gradient over before back-propagation is done
:param use_amp: Whether to use automatic mixed precision (AMP) or not. The options are:
"O0" (FP32)
"O1" (Mixed Precision)
"O2" (Almost FP16)
"O3" (Pure FP16).
For more information, refer to: https://nvidia.github.io/apex/amp.html
:param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
training speed and reduce GPU memory usage.
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
:param optimizer_name: what optimizer to use (default: AdamW)
:param num_warmup_steps: number of warmup steps
:param optimizer_correct_bias: Whether to correct bias in optimizer
@ -687,7 +685,6 @@ class DensePassageRetriever(DenseRetriever):
n_epochs=n_epochs,
grad_acc_steps=grad_acc_steps,
device=self.devices[0], # Only use first device while multi-gpu training is not implemented
use_amp=use_amp,
)

# 6. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
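For the retrievers the call pattern is analogous. A hedged sketch of `DensePassageRetriever.train` with the new boolean flag, assuming a placeholder document store and DPR-format training files (paths and filenames are illustrative only):

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import DensePassageRetriever

# Hypothetical sketch: data_dir and train_filename are placeholders.
retriever = DensePassageRetriever(document_store=InMemoryDocumentStore())
retriever.train(
    data_dir="data/dpr_training",  # assumed directory with DPR-format JSON files
    train_filename="train.json",   # assumed filename
    n_epochs=1,
    batch_size=4,
    grad_acc_steps=8,
    use_amp=True,                  # boolean flag; replaces the apex optimization-level strings
    save_dir="../saved_models/dpr",
)
```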
@ -1236,7 +1233,7 @@ class TableTextRetriever(DenseRetriever):
weight_decay: float = 0.0,
num_warmup_steps: int = 100,
grad_acc_steps: int = 1,
use_amp: Optional[str] = None,
use_amp: bool = False,
optimizer_name: str = "AdamW",
optimizer_correct_bias: bool = True,
save_dir: str = "../saved_models/mm_retrieval",
@ -1273,12 +1270,10 @@ class TableTextRetriever(DenseRetriever):
:param epsilon: Epsilon parameter of optimizer.
:param weight_decay: Weight decay parameter of optimizer.
:param grad_acc_steps: Number of steps to accumulate gradient over before back-propagation is done.
:param use_amp: Whether to use automatic mixed precision (AMP) or not. The options are:
"O0" (FP32)
"O1" (Mixed Precision)
"O2" (Almost FP16)
"O3" (Pure FP16).
For more information, refer to: https://nvidia.github.io/apex/amp.html
:param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
training speed and reduce GPU memory usage.
For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization]
and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html].
:param optimizer_name: What optimizer to use (default: TransformersAdamW).
:param num_warmup_steps: Number of warmup steps.
:param optimizer_correct_bias: Whether to correct bias in optimizer.
@ -1327,7 +1322,6 @@ class TableTextRetriever(DenseRetriever):
n_epochs=n_epochs,
grad_acc_steps=grad_acc_steps,
device=self.devices[0], # Only use first device while multi-gpu training is not implemented
use_amp=use_amp,
)

# 6. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time

@ -1269,7 +1269,7 @@ class EvaluationResult:
Answer metrics are:
- exact_match (Did the query exactly return any gold answer? -> 1.0 or 0.0)
- f1 (How well does the best matching returned results overlap with any gold answer on token basis?)
- sas if a SAS model has bin provided during during pipeline.eval() (How semantically similar is the prediction to the gold answers?)
- sas if a SAS model has been provided during pipeline.eval() (How semantically similar is the prediction to the gold answers?)
"""
multilabel_ids = answers["multilabel_id"].unique()
# simulate top k retriever
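For context, the `sas` metric described above only appears when a semantic answer similarity model is passed to `pipeline.eval()`. A small illustrative sketch, assuming an existing pipeline `pipe` and gold `eval_labels` (both assumed, not part of this diff):

```python
from haystack.schema import EvaluationResult

# Hypothetical sketch: `pipe` is an ExtractiveQAPipeline and `eval_labels` a list of MultiLabel objects.
eval_result: EvaluationResult = pipe.eval(
    labels=eval_labels,
    sas_model_name_or_path="cross-encoder/stsb-roberta-base",  # enables the sas metric
)
metrics = eval_result.calculate_metrics()
print(metrics["Reader"]["exact_match"], metrics["Reader"]["f1"], metrics["Reader"]["sas"])
```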

@ -12,6 +12,8 @@ from haystack.schema import Document, Answer, Label, MultiLabel, Span
from haystack.nodes.reader.base import BaseReader
from haystack.nodes import FARMReader, TransformersReader

from ..conftest import SAMPLES_PATH


# TODO Fix bug in test_no_answer_output when using
# @pytest.fixture(params=["farm", "transformers"])
@ -405,3 +407,31 @@ def test_no_answer_reader_skips_empty_documents(no_answer_reader):
)
assert predictions["answers"][0][0].answer == "" # Return no_answer for 1st query as document is empty
assert predictions["answers"][1][1].answer == "Carla" # answer given for 2nd query as usual


def test_reader_training(tmp_path):
max_seq_len = 16
max_query_length = 8
reader = FARMReader(
model_name_or_path="deepset/tinyroberta-squad2",
use_gpu=False,
num_processes=0,
max_seq_len=max_seq_len,
doc_stride=2,
max_query_length=max_query_length,
)

save_dir = f"{tmp_path}/test_dpr_training"
reader.train(
data_dir=str(SAMPLES_PATH / "squad"),
train_filename="tiny.json",
dev_filename="tiny.json",
test_filename="tiny.json",
n_epochs=1,
batch_size=1,
grad_acc_steps=1,
save_dir=save_dir,
evaluate_every=2,
max_seq_len=max_seq_len,
max_query_length=max_query_length,
)