diff --git a/docs/_src/api/api/evaluation.md b/docs/_src/api/api/evaluation.md index a2ed13102..2a3f5054b 100644 --- a/docs/_src/api/api/evaluation.md +++ b/docs/_src/api/api/evaluation.md @@ -143,7 +143,7 @@ Computes Transformer-based similarity of predicted answer to gold labels to deri Returns per QA pair a) the similarity of the most likely prediction (top 1) to all available gold labels b) the highest similarity of all predictions to gold labels - c) a matrix consisting of the similarities of all the predicitions compared to all gold labels + c) a matrix consisting of the similarities of all the predictions compared to all gold labels **Arguments**: diff --git a/docs/_src/api/api/reader.md b/docs/_src/api/api/reader.md index f142e8d75..7da25f196 100644 --- a/docs/_src/api/api/reader.md +++ b/docs/_src/api/api/reader.md @@ -149,7 +149,7 @@ def train(data_dir: str, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, - use_amp: str = None, + use_amp: bool = False, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, @@ -193,14 +193,10 @@ Note that the evaluation report is logged at evaluation level INFO while Haystac - `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing. Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set. Set to None to use all CPU cores minus one. -- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model. -Available options: -None (Don't use AMP) -"O0" (Normal FP32 training) -"O1" (Mixed Precision => Recommended) -"O2" (Almost FP16) -"O3" (Pure FP16). -See details on: https://nvidia.github.io/apex/amp.html +- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve +training speed and reduce GPU memory usage. +For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] +and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. - `checkpoint_root_dir`: The Path of a directory where all train checkpoints are saved. For each individual checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created. - `checkpoint_every`: Save a train checkpoint after this many steps of training. @@ -237,7 +233,7 @@ def distil_prediction_layer_from( evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, - use_amp: str = None, + use_amp: bool = False, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, @@ -284,7 +280,7 @@ A list containing torch device objects and/or strings is supported (For example [torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices parameter is not used and a single cpu device is used for inference. - `student_batch_size`: Number of samples the student model receives in one batch for training -- `student_batch_size`: Number of samples the teacher model receives in one batch for distillation +- `teacher_batch_size`: Number of samples the teacher model receives in one batch for distillation - `n_epochs`: Number of iterations on the whole training data set - `learning_rate`: Learning rate of the optimizer - `max_seq_len`: Maximum text length (in tokens). Everything longer gets cut down. @@ -296,14 +292,10 @@ Options for different schedules are available in FARM. - `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing. Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set. Set to None to use all CPU cores minus one. -- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model. -Available options: -None (Don't use AMP) -"O0" (Normal FP32 training) -"O1" (Mixed Precision => Recommended) -"O2" (Almost FP16) -"O3" (Pure FP16). -See details on: https://nvidia.github.io/apex/amp.html +- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve +training speed and reduce GPU memory usage. +For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] +and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. - `checkpoint_root_dir`: the Path of directory where all train checkpoints are saved. For each individual checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created. - `checkpoint_every`: save a train checkpoint after this many steps of training. @@ -347,7 +339,7 @@ def distil_intermediate_layers_from( evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, - use_amp: str = None, + use_amp: bool = False, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, @@ -389,8 +381,7 @@ that gets split off from training data for eval. A list containing torch device objects and/or strings is supported (For example [torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices parameter is not used and a single cpu device is used for inference. -- `student_batch_size`: Number of samples the student model receives in one batch for training -- `student_batch_size`: Number of samples the teacher model receives in one batch for distillation +- `batch_size`: Number of samples the student model and teacher model receives in one batch for training - `n_epochs`: Number of iterations on the whole training data set - `learning_rate`: Learning rate of the optimizer - `max_seq_len`: Maximum text length (in tokens). Everything longer gets cut down. @@ -402,21 +393,16 @@ Options for different schedules are available in FARM. - `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing. Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set. Set to None to use all CPU cores minus one. -- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model. -Available options: -None (Don't use AMP) -"O0" (Normal FP32 training) -"O1" (Mixed Precision => Recommended) -"O2" (Almost FP16) -"O3" (Pure FP16). -See details on: https://nvidia.github.io/apex/amp.html +- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve +training speed and reduce GPU memory usage. +For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] +and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. - `checkpoint_root_dir`: the Path of directory where all train checkpoints are saved. For each individual checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created. - `checkpoint_every`: save a train checkpoint after this many steps of training. - `checkpoints_to_keep`: maximum number of train checkpoints to save. - `caching`: whether or not to use caching for preprocessed dataset and teacher logits - `cache_path`: Path to cache the preprocessed dataset and teacher logits -- `distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important. - `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits) - `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model. - `processor`: The processor to use for preprocessing. If None, the default SquadProcessor is used. @@ -663,7 +649,7 @@ Example: **Arguments**: - `question`: Question string -- `documents`: List of documents as string type +- `texts`: A list of Document texts as a string type - `top_k`: The maximum number of answers to return **Returns**: diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md index 65dd156cc..73a178e46 100644 --- a/docs/_src/api/api/retriever.md +++ b/docs/_src/api/api/retriever.md @@ -946,7 +946,7 @@ def train(data_dir: str, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, - use_amp: str = None, + use_amp: bool = False, optimizer_name: str = "AdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr", @@ -984,12 +984,10 @@ you should use the file_system strategy. - `epsilon`: epsilon parameter of optimizer - `weight_decay`: weight decay parameter of optimizer - `grad_acc_steps`: number of steps to accumulate gradient over before back-propagation is done -- `use_amp`: Whether to use automatic mixed precision (AMP) or not. The options are: -"O0" (FP32) -"O1" (Mixed Precision) -"O2" (Almost FP16) -"O3" (Pure FP16). -For more information, refer to: https://nvidia.github.io/apex/amp.html +- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve +training speed and reduce GPU memory usage. +For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] +and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. - `optimizer_name`: what optimizer to use (default: AdamW) - `num_warmup_steps`: number of warmup steps - `optimizer_correct_bias`: Whether to correct bias in optimizer @@ -1305,7 +1303,7 @@ def train(data_dir: str, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, - use_amp: str = None, + use_amp: bool = False, optimizer_name: str = "AdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/mm_retrieval", @@ -1345,12 +1343,10 @@ very similar (high score by BM25) to query but do not contain the answer)- - `epsilon`: Epsilon parameter of optimizer. - `weight_decay`: Weight decay parameter of optimizer. - `grad_acc_steps`: Number of steps to accumulate gradient over before back-propagation is done. -- `use_amp`: Whether to use automatic mixed precision (AMP) or not. The options are: -"O0" (FP32) -"O1" (Mixed Precision) -"O2" (Almost FP16) -"O3" (Pure FP16). -For more information, refer to: https://nvidia.github.io/apex/amp.html +- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve +training speed and reduce GPU memory usage. +For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] +and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. - `optimizer_name`: What optimizer to use (default: TransformersAdamW). - `num_warmup_steps`: Number of warmup steps. - `optimizer_correct_bias`: Whether to correct bias in optimizer. diff --git a/haystack/modeling/data_handler/processor.py b/haystack/modeling/data_handler/processor.py index 4fba61119..57b8cce0b 100644 --- a/haystack/modeling/data_handler/processor.py +++ b/haystack/modeling/data_handler/processor.py @@ -72,7 +72,7 @@ class Processor(ABC): :param dev_filename: The name of the file containing the dev data. If None and 0.0 < dev_split < 1.0 the dev set will be a slice of the train set. :param test_filename: The name of the file containing test data. - :param dev_split: The proportion of the train set that will sliced. Only works if dev_filename is set to None + :param dev_split: The proportion of the train set that will be sliced. Only works if `dev_filename` is set to `None`. :param data_dir: The directory in which the train, test and perhaps dev files can be found. :param tasks: Tasks for which the processor shall extract labels from the input data. Usually this includes a single, default task, e.g. text classification. @@ -137,7 +137,7 @@ class Processor(ABC): If None and 0.0 < dev_split < 1.0 the dev set will be a slice of the train set. :param test_filename: The name of the file containing test data. - :param dev_split: The proportion of the train set that will sliced. + :param dev_split: The proportion of the train set that will be sliced. Only works if dev_filename is set to None :param kwargs: placeholder for passing generic parameters :return: An instance of the specified processor. @@ -217,6 +217,7 @@ class Processor(ABC): tokenizer_class=None, tokenizer_args=None, use_fast=True, + max_query_length=64, **kwargs, ): tokenizer_args = tokenizer_args or {} @@ -238,6 +239,7 @@ class Processor(ABC): metric="squad", data_dir="data", doc_stride=doc_stride, + max_query_length=max_query_length, ) elif task_type == "embeddings": processor = InferenceProcessor(tokenizer=tokenizer, max_seq_len=max_seq_len) @@ -396,7 +398,7 @@ class SquadProcessor(Processor): :param dev_filename: The name of the file containing the dev data. If None and 0.0 < dev_split < 1.0 the dev set will be a slice of the train set. :param test_filename: None - :param dev_split: The proportion of the train set that will sliced. Only works if dev_filename is set to None + :param dev_split: The proportion of the train set that will be sliced. Only works if `dev_filename` is set to `None`. :param doc_stride: When the document containing the answer is too long it gets split into part, strided by doc_stride :param max_query_length: Maximum length of the question (in number of subword tokens) :param proxies: proxy configuration to allow downloads of remote datasets. diff --git a/haystack/modeling/infer.py b/haystack/modeling/infer.py index dcc953ee6..a1e08aab4 100644 --- a/haystack/modeling/infer.py +++ b/haystack/modeling/infer.py @@ -130,6 +130,7 @@ class Inferencer: multithreading_rust: bool = True, use_auth_token: Optional[Union[bool, str]] = None, devices: Optional[List[Union[str, torch.device]]] = None, + max_query_length: int = 64, **kwargs, ): """ @@ -178,6 +179,7 @@ class Inferencer: `transformers-cli login` (stored in ~/.huggingface) will be used. Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained + :param max_query_length: Only QA: Maximum length of the question in number of tokens. :return: An instance of the Inferencer. """ if tokenizer_args is None: @@ -228,6 +230,7 @@ class Inferencer: tokenizer_args=tokenizer_args, use_fast=use_fast, use_auth_token=use_auth_token, + max_query_length=max_query_length, **kwargs, ) @@ -241,6 +244,8 @@ class Inferencer: "Please set a lower value for doc_stride (Suggestions: doc_stride=128, max_seq_len=384) " ) processor.doc_stride = doc_stride + if hasattr(processor, "max_query_length"): + processor.max_query_length = max_query_length return cls( model, diff --git a/haystack/modeling/model/adaptive_model.py b/haystack/modeling/model/adaptive_model.py index 3e564a8ae..69605f335 100644 --- a/haystack/modeling/model/adaptive_model.py +++ b/haystack/modeling/model/adaptive_model.py @@ -280,7 +280,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel): * vocab.txt vocab file for language model, turning text to Wordpiece Tokens :param load_dir: Location where the AdaptiveModel is stored. - :param device: To which device we want to sent the model, either torch.device("cpu") or torch.device("cuda"). + :param device: Specifies the device to which you want to send the model, either torch.device("cpu") or torch.device("cuda"). :param strict: Whether to strictly enforce that the keys loaded from saved model match the ones in the PredictionHead (see torch.nn.module.load_state_dict()). :param processor: Processor to populate prediction head with information coming from tasks. diff --git a/haystack/modeling/model/optimization.py b/haystack/modeling/model/optimization.py index becb11600..181601d57 100644 --- a/haystack/modeling/model/optimization.py +++ b/haystack/modeling/model/optimization.py @@ -14,29 +14,6 @@ from haystack.utils.experiment_tracking import Tracker as tracker logger = logging.getLogger(__name__) -try: - from apex import amp # pylint: disable=import-error - - logger.info("apex is available.") - - try: - from apex.parallel import convert_syncbn_model # pylint: disable=import-error - - APEX_PARALLEL_AVAILABLE = True - - logger.info("apex.parallel is available.") - - except AttributeError: - APEX_PARALLEL_AVAILABLE = False - logger.debug("apex.parallel not found, won't use it. See https://nvidia.github.io/apex/parallel.html") - - AMP_AVAILABLE = True - -except ImportError: - AMP_AVAILABLE = False - APEX_PARALLEL_AVAILABLE = False - logger.debug("apex not found, won't use it. See https://nvidia.github.io/apex/") - class WrappedDataParallel(DataParallel): """ @@ -57,7 +34,8 @@ class WrappedDataParallel(DataParallel): class WrappedDDP(DistributedDataParallel): """ A way of adapting attributes of underlying class to distributed mode. Same as in WrappedDataParallel above. - Even when using distributed on a single machine with multiple GPUs, apex can speed up training significantly. + Even when using distributed on a single computer with multiple GPUs, automatic mixed precision can speed up training + significantly. Distributed code must be launched with "python -m torch.distributed.launch --nproc_per_node=1 run_script.py" """ @@ -79,7 +57,7 @@ def initialize_optimizer( distributed: bool = False, grad_acc_steps: int = 1, local_rank: int = -1, - use_amp: Optional[str] = None, + use_amp: bool = False, ): """ Initializes an optimizer, a learning rate scheduler and converts the model if needed (e.g for mixed precision). @@ -91,15 +69,13 @@ def initialize_optimizer( :param n_epochs: number of epochs for training :param device: Which hardware will be used by the optimizer. Either torch.device("cpu") or torch.device("cuda"). :param learning_rate: Learning rate - :param optimizer_opts: Dict to customize the optimizer. Choose any optimizer available from torch.optim, apex.optimizers or + :param optimizer_opts: Dictionary to customize the optimizer. Choose any optimizer available from torch.optim or transformers.optimization by supplying the class name and the parameters for the constructor. Examples: 1) AdamW from Transformers (Default): {"name": "AdamW", "correct_bias": False, "weight_decay": 0.01} 2) SGD from pytorch: {"name": "SGD", "momentum": 0.0} - 3) FusedLAMB from apex: - {"name": "FusedLAMB", "bias_correction": True} :param schedule_opts: Dict to customize the learning rate schedule. Choose any Schedule from Pytorch or Huggingface's Transformers by supplying the class name and the parameters needed by the constructor. @@ -119,20 +95,17 @@ def initialize_optimizer( :param distributed: Whether training on distributed machines :param grad_acc_steps: Number of steps to accumulate gradients for. Helpful to mimic large batch_sizes on small machines. :param local_rank: rank of the machine in a distributed setting - :param use_amp: Optimization level of nvidia's automatic mixed precision (AMP). The higher the level, the faster the model. - Options: - "O0" (Normal FP32 training) - "O1" (Mixed Precision => Recommended) - "O2" (Almost FP16) - "O3" (Pure FP16). - See details on: https://nvidia.github.io/apex/amp.html + :param use_amp: This option is deprecated. Haystack supports only PyTorch automatic mixed precision (AMP). We no longer support the Apex library. This means this function doesn't use `use_amp` any longer + because it's not needed to initialize native Pytorch AMP. If you provide a value, you'll see a warning message. + :return: model, optimizer, scheduler """ - if use_amp and not AMP_AVAILABLE: - raise ImportError( - f"Got use_amp = {use_amp}, but cannot find apex. " - "Please install Apex if you want to make use of automatic mixed precision. " - "https://github.com/NVIDIA/apex" + if isinstance(use_amp, str): + logger.warning( + "Haystack supports only PyTorch automatic mixed precision. We no longer support the Apex library.\n" + "This means that modeling.model.initialize_optimizer no longer uses use_amp since it is not needed\n" + "to initialize native PyTorch automatic mixed precision. For more information, see [Optimization](https://haystack.deepset.ai/guides/optimization).\n" + "In the future provide use_amp=True to use automatic mixed precision." ) if (schedule_opts is not None) and (not isinstance(schedule_opts, dict)): @@ -161,13 +134,13 @@ def initialize_optimizer( schedule_opts["num_training_steps"] = num_train_optimization_steps # Log params - tracker.track_params({"use_amp": use_amp, "num_train_optimization_steps": schedule_opts["num_training_steps"]}) + tracker.track_params({"num_train_optimization_steps": schedule_opts["num_training_steps"]}) - # Get optimizer from pytorch, transformers or apex + # Get optimizer from pytorch or transformers optimizer = _get_optim(model, optimizer_opts) - # Adjust for parallel training + amp - model, optimizer = optimize_model(model, device, local_rank, optimizer, distributed, use_amp) + # Adjust for parallel training + model, optimizer = optimize_model(model, device, local_rank, optimizer, distributed) # Get learning rate schedule - moved below to supress warning scheduler = get_scheduler(optimizer, schedule_opts) @@ -219,7 +192,7 @@ def _get_optim(model, opts: Dict): if weight_decay is not None: optimizable_parameters[0]["weight_decay"] = weight_decay # type: ignore - # Import optimizer by checking in order: torch, transformers, apex and local imports + # Import optimizer by checking in order: torch, transformers and local imports try: optim_constructor = getattr(import_module("torch.optim"), optimizer_name) except AttributeError: @@ -227,17 +200,14 @@ def _get_optim(model, opts: Dict): optim_constructor = getattr(import_module("transformers.optimization"), optimizer_name) except AttributeError: try: - optim_constructor = getattr(import_module("apex.optimizers"), optimizer_name) + # Workaround to allow loading AdamW from transformers + # pytorch > 1.2 has now also a AdamW (but without the option to set bias_correction = False, + # which is done in the original BERT implementation) + optim_constructor = getattr(sys.modules[__name__], optimizer_name) except (AttributeError, ImportError): - try: - # Workaround to allow loading AdamW from transformers - # pytorch > 1.2 has now also a AdamW (but without the option to set bias_correction = False, - # which is done in the original BERT implementation) - optim_constructor = getattr(sys.modules[__name__], optimizer_name) - except (AttributeError, ImportError): - raise AttributeError( - f"Optimizer '{optimizer_name}' not found in 'torch', 'transformers', 'apex' or 'local imports" - ) + raise AttributeError( + f"We couldn't find optimizer '{optimizer_name}' in 'torch', 'transformers' or 'local imports'." + ) return optim_constructor(optimizable_parameters) @@ -298,9 +268,9 @@ def optimize_model( model: "AdaptiveModel", device: torch.device, local_rank: int, - optimizer=None, - distributed: Optional[bool] = False, - use_amp: Optional[str] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + distributed: bool = False, + use_amp: bool = False, ): """ Wraps MultiGPU or distributed usage around a model @@ -310,24 +280,23 @@ def optimize_model( :param device: either torch.device("cpu") or torch.device("cuda"). Get the device from `initialize_device_settings()` :param distributed: Whether training on distributed machines :param local_rank: rank of the machine in a distributed setting - :param use_amp: Optimization level of nvidia's automatic mixed precision (AMP). The higher the level, the faster the model. - Options: - "O0" (Normal FP32 training) - "O1" (Mixed Precision => Recommended) - "O2" (Almost FP16) - "O3" (Pure FP16). - See details on: https://nvidia.github.io/apex/amp.html + :param optimizer: torch optimizer + :param use_amp: This option is deprecated. Haystack supports only PyTorch automatic mixed precision (AMP). We no longer support the Apex library. This means this function no longer uses `use_amp` + because it's not needed to initialize native Pytorch AMP. If you provide a value, you'll see a warning message. + :return: model, optimizer """ - model, optimizer = _init_amp(model, device, optimizer, use_amp) + if isinstance(use_amp, str): + logger.warning( + "Haystack supports only PyTorch automatic mixed precision. We no longer support the Apex library.\n" + "This means that modeling.model.initialize_optimizer no longer uses use_amp since it's not needed\n" + "to initialize native PyTorch automatic mixed precision. For more information, see [Optimization](https://haystack.deepset.ai/guides/optimization).\n" + "In the future, set `use_amp=True` to use automatic mixed precision." + ) + + model = model.to(device) if distributed: - if APEX_PARALLEL_AVAILABLE: - model = convert_syncbn_model(model) - logger.info("Multi-GPU Training via DistributedDataParallel and apex.parallel") - else: - logger.info("Multi-GPU Training via DistributedDataParallel") - # for some models DistributedDataParallel might complain about parameters # not contributing to loss. find_used_parameters remedies that. model = WrappedDDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) @@ -337,16 +306,3 @@ def optimize_model( logger.info("Multi-GPU Training via DataParallel") return model, optimizer - - -def _init_amp(model, device, optimizer=None, use_amp=None): - model = model.to(device) - if use_amp and optimizer: - if AMP_AVAILABLE: - model, optimizer = amp.initialize(model, optimizer, opt_level=use_amp) - else: - logger.warning( - f"Can't find AMP although you specificed to use amp with level {use_amp}. Will continue without AMP ..." - ) - - return model, optimizer diff --git a/haystack/modeling/model/predictions.py b/haystack/modeling/model/predictions.py index c9d208f84..10c684827 100644 --- a/haystack/modeling/model/predictions.py +++ b/haystack/modeling/model/predictions.py @@ -297,10 +297,10 @@ class QAPred(Pred): def _answers_to_json(self, ext_id, squad=False) -> List[Dict]: """ - Convert all answers into a json format + Convert all answers into a json format. - :param id: ID of the question document pair - :param squad: If True, no_answers are represented by the empty string instead of "no_answer" + :param ext_id: ID of the question document pair. + :param squad: If True, no_answers are represented by the empty string instead of "no_answer". """ ret = [] diff --git a/haystack/modeling/training/base.py b/haystack/modeling/training/base.py index 1d6f84d0f..f72afe3f7 100644 --- a/haystack/modeling/training/base.py +++ b/haystack/modeling/training/base.py @@ -23,13 +23,6 @@ from haystack.modeling.utils import GracefulKiller from haystack.utils.experiment_tracking import Tracker as tracker from haystack.utils.early_stopping import EarlyStopping -try: - from apex import amp - - AMP_AVAILABLE = True -except ImportError: - AMP_AVAILABLE = False - logger = logging.getLogger(__name__) @@ -51,7 +44,7 @@ class Trainer: lr_schedule=None, evaluate_every: int = 100, eval_report: bool = True, - use_amp: Optional[str] = None, + use_amp: bool = False, grad_acc_steps: int = 1, local_rank: int = -1, early_stopping: Optional[EarlyStopping] = None, @@ -77,8 +70,10 @@ class Trainer: :param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer :param evaluate_every: Perform dev set evaluation after this many steps of training. :param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating - :param use_amp: Whether to use automatic mixed precision with Apex. One of the optimization levels must be chosen. - "O1" is recommended in almost all cases. + :param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve + training speed and reduce GPU memory usage. + For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] + and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. :param grad_acc_steps: Number of training steps for which the gradients should be accumulated. Useful to achieve larger effective batch sizes that would not fit in GPU memory. :param local_rank: Local rank of process when distributed training via DDP is used. @@ -88,10 +83,10 @@ class Trainer: :param checkpoint_on_sigterm: save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where a SIGTERM notifies to save the training state and subsequently the instance is terminated. - :param checkpoint_every: save a train checkpoint after this many steps of training. - :param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual + :param checkpoint_every: Save a training checkpoint after this many steps of training. + :param checkpoint_root_dir: The directory Path where all training checkpoints are saved. For each individual checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created. - :param checkpoints_to_keep: maximum number of train checkpoints to save. + :param checkpoints_to_keep: The maximum number of training checkpoints to save. :param from_epoch: the epoch number to start the training from. In the case when training resumes from a saved checkpoint, it is used to fast-forward training to the last epoch in the checkpoint. :param from_step: the step number to start the training from. In the case when training resumes from a saved @@ -101,16 +96,30 @@ class Trainer: :param disable_tqdm: Disable tqdm progress bar (helps to reduce verbosity in some environments) :param max_grad_norm: Max gradient norm for clipping, default 1.0, set to None to disable """ + amp_mapping = {"O0": False, "O1": True, "O2": True, "O3": True} self.model = model self.data_silo = data_silo self.epochs = int(epochs) + if isinstance(use_amp, str): + if use_amp in amp_mapping: + logger.warning( + "The Trainer only supports native PyTorch automatic mixed precision and no longer supports the Apex library.\n" + f"Because you provided Apex optimization level {use_amp}, automatic mixed precision was set to {amp_mapping[use_amp]}.\n" + "In the future, set `use_amp=True` to turn on automatic mixed precision." + ) + use_amp = amp_mapping[use_amp] + else: + raise Exception( + f"use_amp value {use_amp} is not supported. Please provide use_amp=True to turn on automatic mixed precision." + ) + self.use_amp = use_amp self.optimizer = optimizer + self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp) self.evaluate_every = evaluate_every self.eval_report = eval_report self.evaluator_test = evaluator_test self.n_gpu = n_gpu self.grad_acc_steps = grad_acc_steps - self.use_amp = use_amp self.lr_schedule = lr_schedule self.device = device self.local_rank = local_rank @@ -122,12 +131,6 @@ class Trainer: self.max_grad_norm = max_grad_norm self.test_result = None - if use_amp and not AMP_AVAILABLE: - raise ImportError( - f"Got use_amp = {use_amp}, but cannot find apex. " - "Please install Apex if you want to make use of automatic mixed precision. " - "https://github.com/NVIDIA/apex" - ) self.checkpoint_on_sigterm = checkpoint_on_sigterm if checkpoint_on_sigterm: self.sigterm_handler = GracefulKiller() # type: Optional[GracefulKiller] @@ -203,7 +206,7 @@ class Trainer: # Only for distributed training: we need to ensure that all ranks still have a batch left for training if self.local_rank != -1: - if not self._all_ranks_have_data(has_data=1, step=step): + if not self._all_ranks_have_data(has_data=True, step=step): early_break = True break @@ -297,47 +300,44 @@ class Trainer: else: module = self.model - if isinstance(module, AdaptiveModel): - logits = self.model.forward( - input_ids=batch["input_ids"], segment_ids=None, padding_mask=batch["padding_mask"] - ) + with torch.cuda.amp.autocast(enabled=self.use_amp): + if isinstance(module, AdaptiveModel): + logits = self.model.forward( + input_ids=batch["input_ids"], segment_ids=None, padding_mask=batch["padding_mask"] + ) - elif isinstance(module, BiAdaptiveModel): - logits = self.model.forward( - query_input_ids=batch["query_input_ids"], - query_segment_ids=batch["query_segment_ids"], - query_attention_mask=batch["query_attention_mask"], - passage_input_ids=batch["passage_input_ids"], - passage_segment_ids=batch["passage_segment_ids"], - passage_attention_mask=batch["passage_attention_mask"], - ) + elif isinstance(module, BiAdaptiveModel): + logits = self.model.forward( + query_input_ids=batch["query_input_ids"], + query_segment_ids=batch["query_segment_ids"], + query_attention_mask=batch["query_attention_mask"], + passage_input_ids=batch["passage_input_ids"], + passage_segment_ids=batch["passage_segment_ids"], + passage_attention_mask=batch["passage_attention_mask"], + ) - else: - logits = self.model.forward(**batch) + else: + logits = self.model.forward(**batch) - per_sample_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch) - return self.backward_propagate(per_sample_loss, step) + per_sample_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch) + loss = self.adjust_loss(per_sample_loss) + return self.backward_propagate(loss, step) def backward_propagate(self, loss: torch.Tensor, step: int): - loss = self.adjust_loss(loss) if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0]: if self.local_rank in [-1, 0]: tracker.track_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step) if self.log_learning_rate: tracker.track_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step) - if self.use_amp: - with amp.scale_loss(loss, self.optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + + self.scaler.scale(loss).backward() if step % self.grad_acc_steps == 0: if self.max_grad_norm is not None: - if self.use_amp: - torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.max_grad_norm) - else: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) - self.optimizer.step() + self.scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.scaler.step(self.optimizer) + self.scaler.update() self.optimizer.zero_grad() if self.lr_schedule: self.lr_schedule.step() @@ -350,7 +350,7 @@ class Trainer: return loss def log_params(self): - params = {"epochs": self.epochs, "n_gpu": self.n_gpu, "device": self.device} + params = {"epochs": self.epochs, "n_gpu": self.n_gpu, "device": self.device, "use_amp": self.use_amp} tracker.track_params(params) @classmethod @@ -545,6 +545,7 @@ class Trainer: "log_learning_rate": self.log_learning_rate, "log_loss_every": self.log_loss_every, "disable_tqdm": self.disable_tqdm, + "use_amp": self.use_amp, } return state_dict @@ -580,7 +581,7 @@ class Trainer: class DistillationTrainer(Trainer): """ This trainer uses the teacher logits from DistillationDataSilo - to compute a distillation loss in addtion to the loss based on the labels. + to compute a distillation loss in addition to the loss based on the labels. **Example** ```python @@ -608,7 +609,7 @@ class DistillationTrainer(Trainer): lr_schedule: Optional[_LRScheduler] = None, evaluate_every: int = 100, eval_report: bool = True, - use_amp: Optional[str] = None, + use_amp: bool = False, grad_acc_steps: int = 1, local_rank: int = -1, early_stopping: Optional[EarlyStopping] = None, @@ -631,7 +632,6 @@ class DistillationTrainer(Trainer): """ :param optimizer: An optimizer object that determines the learning strategy to be used during training :param model: The model to be trained - :param teacher_model: The teacher model used for distillation :param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders :param epochs: How many times the training procedure will loop through the train dataset :param n_gpu: The number of gpus available for training and evaluation. @@ -639,8 +639,10 @@ class DistillationTrainer(Trainer): :param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer :param evaluate_every: Perform dev set evaluation after this many steps of training. :param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating - :param use_amp: Whether to use automatic mixed precision with Apex. One of the optimization levels must be chosen. - "O1" is recommended in almost all cases. + :param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve + training speed and reduce GPU memory usage. + For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] + and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. :param grad_acc_steps: Number of training steps for which the gradients should be accumulated. Useful to achieve larger effective batch sizes that would not fit in GPU memory. :param local_rank: Local rank of process when distributed training via DDP is used. @@ -709,23 +711,23 @@ class DistillationTrainer(Trainer): keys = list(batch.keys()) keys = [key for key in keys if key.startswith("teacher_output")] teacher_logits = [batch.pop(key) for key in keys] - - logits = self.model.forward( - input_ids=batch.get("input_ids"), - segment_ids=batch.get("segment_ids"), - padding_mask=batch.get("padding_mask"), - output_hidden_states=batch.get("output_hidden_states"), - output_attentions=batch.get("output_attentions"), - ) - - student_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch) - distillation_loss = self.distillation_loss_fn( - student_logits=logits[0] / self.temperature, teacher_logits=teacher_logits[0] / self.temperature - ) - combined_loss = distillation_loss * self.distillation_loss_weight * (self.temperature**2) + student_loss * ( - 1 - self.distillation_loss_weight - ) - return self.backward_propagate(combined_loss, step) + with torch.cuda.amp.autocast(enabled=self.use_amp): + logits = self.model.forward( + input_ids=batch.get("input_ids"), + segment_ids=batch.get("segment_ids"), + padding_mask=batch.get("padding_mask"), + output_hidden_states=batch.get("output_hidden_states"), + output_attentions=batch.get("output_attentions"), + ) + student_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch) + distillation_loss = self.distillation_loss_fn( + student_logits=logits[0] / self.temperature, teacher_logits=teacher_logits[0] / self.temperature + ) + combined_loss = distillation_loss * self.distillation_loss_weight * ( + self.temperature**2 + ) + student_loss * (1 - self.distillation_loss_weight) + loss = self.adjust_loss(combined_loss) + return self.backward_propagate(loss, step) class TinyBERTDistillationTrainer(Trainer): @@ -761,7 +763,7 @@ class TinyBERTDistillationTrainer(Trainer): lr_schedule: Optional[_LRScheduler] = None, evaluate_every: int = 100, eval_report: bool = True, - use_amp: Optional[str] = None, + use_amp: bool = False, grad_acc_steps: int = 1, local_rank: int = -1, early_stopping: Optional[EarlyStopping] = None, @@ -789,8 +791,10 @@ class TinyBERTDistillationTrainer(Trainer): :param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer :param evaluate_every: Perform dev set evaluation after this many steps of training. :param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating - :param use_amp: Whether to use automatic mixed precision with Apex. One of the optimization levels must be chosen. - "O1" is recommended in almost all cases. + :param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve + training speed and reduce GPU memory usage. + For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] + and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. :param grad_acc_steps: Number of training steps for which the gradients should be accumulated. Useful to achieve larger effective batch sizes that would not fit in GPU memory. :param local_rank: Local rank of process when distributed training via DDP is used. @@ -849,16 +853,16 @@ class TinyBERTDistillationTrainer(Trainer): self.loss = DataParallel(self.loss).to(device) def compute_loss(self, batch: dict, step: int) -> torch.Tensor: - return self.backward_propagate( - torch.sum( + with torch.cuda.amp.autocast(enabled=self.use_amp): + loss = torch.sum( self.loss( input_ids=batch.get("input_ids"), segment_ids=batch.get("segment_ids"), padding_mask=batch.get("padding_mask"), ) - ), - step, - ) + ) + loss = self.adjust_loss(loss) + return self.backward_propagate(loss, step) class DistillationLoss(Module): diff --git a/haystack/modeling/utils.py b/haystack/modeling/utils.py index 3dba128c5..456ccf44f 100644 --- a/haystack/modeling/utils.py +++ b/haystack/modeling/utils.py @@ -62,7 +62,7 @@ def set_all_seeds(seed: int, deterministic_cudnn: bool = False) -> None: but might slow down your training (see https://pytorch.org/docs/stable/notes/randomness.html#cudnn) ! :param seed:number to use as seed - :param deterministic_torch: Enable for full reproducibility when using CUDA. Caution: might slow down training. + :param deterministic_cudnn: Enable for full reproducibility when using CUDA. Caution: might slow down training. """ random.seed(seed) np.random.seed(seed) diff --git a/haystack/nodes/evaluator/evaluator.py b/haystack/nodes/evaluator/evaluator.py index 2b4c245eb..af6ba6755 100644 --- a/haystack/nodes/evaluator/evaluator.py +++ b/haystack/nodes/evaluator/evaluator.py @@ -407,7 +407,7 @@ def semantic_answer_similarity( Computes Transformer-based similarity of predicted answer to gold labels to derive a more meaningful metric than EM or F1. Returns per QA pair a) the similarity of the most likely prediction (top 1) to all available gold labels b) the highest similarity of all predictions to gold labels - c) a matrix consisting of the similarities of all the predicitions compared to all gold labels + c) a matrix consisting of the similarities of all the predictions compared to all gold labels :param predictions: Predicted answers as list of multiple preds per question :param gold_labels: Labels as list of multiple possible answers per question diff --git a/haystack/nodes/reader/farm.py b/haystack/nodes/reader/farm.py index b82c40205..9c396d93a 100644 --- a/haystack/nodes/reader/farm.py +++ b/haystack/nodes/reader/farm.py @@ -68,6 +68,7 @@ class FARMReader(BaseReader): local_files_only=False, force_download=False, use_auth_token: Optional[Union[str, bool]] = None, + max_query_length: int = 64, ): """ @@ -128,6 +129,7 @@ class FARMReader(BaseReader): `transformers-cli login` (stored in ~/.huggingface) will be used. Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained + :param max_query_length: Maximum length of the question in number of tokens. """ super().__init__() @@ -151,6 +153,7 @@ class FARMReader(BaseReader): force_download=force_download, devices=self.devices, use_auth_token=use_auth_token, + max_query_length=max_query_length, ) self.inferencer.model.prediction_heads[0].context_window_size = context_window_size self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost @@ -159,6 +162,8 @@ class FARMReader(BaseReader): self.inferencer.model.prediction_heads[0].duplicate_filtering = duplicate_filtering self.inferencer.model.prediction_heads[0].use_confidence_scores_for_ranking = use_confidence_scores self.max_seq_len = max_seq_len + self.doc_stride = doc_stride + self.max_query_length = max_query_length self.progress_bar = progress_bar self.use_confidence_scores = use_confidence_scores self.confidence_threshold = confidence_threshold @@ -181,7 +186,7 @@ class FARMReader(BaseReader): evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, - use_amp: Optional[str] = None, + use_amp: bool = False, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, @@ -196,6 +201,9 @@ class FARMReader(BaseReader): processor: Optional[Processor] = None, grad_acc_steps: int = 1, early_stopping: Optional[EarlyStopping] = None, + distributed: bool = False, + doc_stride: Optional[int] = None, + max_query_length: Optional[int] = None, ): if dev_filename: dev_split = 0 @@ -211,6 +219,10 @@ class FARMReader(BaseReader): devices = self.devices if max_seq_len is None: max_seq_len = self.max_seq_len + if doc_stride is None: + doc_stride = self.doc_stride + if max_query_length is None: + max_query_length = self.max_query_length devices, n_gpu = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False) @@ -226,6 +238,8 @@ class FARMReader(BaseReader): processor = SquadProcessor( tokenizer=self.inferencer.processor.tokenizer, max_seq_len=max_seq_len, + max_query_length=max_query_length, + doc_stride=doc_stride, label_list=label_list, metric=metric, train_filename=train_filename, @@ -247,7 +261,7 @@ class FARMReader(BaseReader): device=devices[0], processor=processor, batch_size=batch_size, - distributed=False, + distributed=distributed, max_processes=num_processes, caching=caching, cache_path=cache_path, @@ -256,7 +270,7 @@ class FARMReader(BaseReader): data_silo = DataSilo( processor=processor, batch_size=batch_size, - distributed=False, + distributed=distributed, max_processes=num_processes, caching=caching, cache_path=cache_path, @@ -265,14 +279,13 @@ class FARMReader(BaseReader): # 3. Create an optimizer and pass the already initialized model model, optimizer, lr_schedule = initialize_optimizer( model=self.inferencer.model, - # model=self.inferencer.model, learning_rate=learning_rate, schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion}, n_batches=len(data_silo.loaders["train"]), n_epochs=n_epochs, device=devices[0], - use_amp=use_amp, grad_acc_steps=grad_acc_steps, + distributed=distributed, ) # 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time if tinybert: @@ -360,7 +373,7 @@ class FARMReader(BaseReader): evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, - use_amp: Optional[str] = None, + use_amp: bool = False, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, @@ -368,6 +381,7 @@ class FARMReader(BaseReader): cache_path: Path = Path("cache/data_silo"), grad_acc_steps: int = 1, early_stopping: Optional[EarlyStopping] = None, + max_query_length: Optional[int] = None, ): """ Fine-tune a model on a QA dataset. Options: @@ -404,14 +418,10 @@ class FARMReader(BaseReader): :param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing. Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set. Set to None to use all CPU cores minus one. - :param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model. - Available options: - None (Don't use AMP) - "O0" (Normal FP32 training) - "O1" (Mixed Precision => Recommended) - "O2" (Almost FP16) - "O3" (Pure FP16). - See details on: https://nvidia.github.io/apex/amp.html + :param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve + training speed and reduce GPU memory usage. + For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] + and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. :param checkpoint_root_dir: The Path of a directory where all train checkpoints are saved. For each individual checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created. :param checkpoint_every: Save a train checkpoint after this many steps of training. @@ -420,6 +430,7 @@ class FARMReader(BaseReader): :param cache_path: The Path to cache the preprocessed dataset. :param grad_acc_steps: The number of steps to accumulate gradients for before performing a backward pass. :param early_stopping: An initialized EarlyStopping object to control early stopping and saving of the best models. + :param max_query_length: Maximum length of the question in number of tokens. :return: None """ return self._training_procedure( @@ -446,6 +457,8 @@ class FARMReader(BaseReader): cache_path=cache_path, grad_acc_steps=grad_acc_steps, early_stopping=early_stopping, + max_query_length=max_query_length, + distributed=False, ) def distil_prediction_layer_from( @@ -467,7 +480,7 @@ class FARMReader(BaseReader): evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, - use_amp: Optional[str] = None, + use_amp: bool = False, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, @@ -510,7 +523,7 @@ class FARMReader(BaseReader): [torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices parameter is not used and a single cpu device is used for inference. :param student_batch_size: Number of samples the student model receives in one batch for training - :param student_batch_size: Number of samples the teacher model receives in one batch for distillation + :param teacher_batch_size: Number of samples the teacher model receives in one batch for distillation :param n_epochs: Number of iterations on the whole training data set :param learning_rate: Learning rate of the optimizer :param max_seq_len: Maximum text length (in tokens). Everything longer gets cut down. @@ -522,14 +535,10 @@ class FARMReader(BaseReader): :param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing. Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set. Set to None to use all CPU cores minus one. - :param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model. - Available options: - None (Don't use AMP) - "O0" (Normal FP32 training) - "O1" (Mixed Precision => Recommended) - "O2" (Almost FP16) - "O3" (Pure FP16). - See details on: https://nvidia.github.io/apex/amp.html + :param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve + training speed and reduce GPU memory usage. + For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] + and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. :param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created. :param checkpoint_every: save a train checkpoint after this many steps of training. @@ -577,6 +586,7 @@ class FARMReader(BaseReader): temperature=temperature, grad_acc_steps=grad_acc_steps, early_stopping=early_stopping, + distributed=False, ) def distil_intermediate_layers_from( @@ -597,7 +607,7 @@ class FARMReader(BaseReader): evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, - use_amp: Optional[str] = None, + use_amp: bool = False, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, @@ -635,8 +645,7 @@ class FARMReader(BaseReader): A list containing torch device objects and/or strings is supported (For example [torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices parameter is not used and a single cpu device is used for inference. - :param student_batch_size: Number of samples the student model receives in one batch for training - :param student_batch_size: Number of samples the teacher model receives in one batch for distillation + :param batch_size: Number of samples the student model and teacher model receives in one batch for training :param n_epochs: Number of iterations on the whole training data set :param learning_rate: Learning rate of the optimizer :param max_seq_len: Maximum text length (in tokens). Everything longer gets cut down. @@ -648,21 +657,16 @@ class FARMReader(BaseReader): :param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing. Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set. Set to None to use all CPU cores minus one. - :param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model. - Available options: - None (Don't use AMP) - "O0" (Normal FP32 training) - "O1" (Mixed Precision => Recommended) - "O2" (Almost FP16) - "O3" (Pure FP16). - See details on: https://nvidia.github.io/apex/amp.html + :param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve + training speed and reduce GPU memory usage. + For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] + and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. :param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created. :param checkpoint_every: save a train checkpoint after this many steps of training. :param checkpoints_to_keep: maximum number of train checkpoints to save. :param caching: whether or not to use caching for preprocessed dataset and teacher logits :param cache_path: Path to cache the preprocessed dataset and teacher logits - :param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important. :param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits) :param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model. :param processor: The processor to use for preprocessing. If None, the default SquadProcessor is used. @@ -700,6 +704,7 @@ class FARMReader(BaseReader): processor=processor, grad_acc_steps=grad_acc_steps, early_stopping=early_stopping, + distributed=False, ) def update_parameters( @@ -1297,7 +1302,7 @@ class FARMReader(BaseReader): ``` :param question: Question string - :param documents: List of documents as string type + :param texts: A list of Document texts as a string type :param top_k: The maximum number of answers to return :return: Dict containing question and answers """ diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index 125fda861..ea44f5885 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -593,7 +593,7 @@ class DensePassageRetriever(DenseRetriever): weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, - use_amp: Optional[str] = None, + use_amp: bool = False, optimizer_name: str = "AdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr", @@ -628,12 +628,10 @@ class DensePassageRetriever(DenseRetriever): :param epsilon: epsilon parameter of optimizer :param weight_decay: weight decay parameter of optimizer :param grad_acc_steps: number of steps to accumulate gradient over before back-propagation is done - :param use_amp: Whether to use automatic mixed precision (AMP) or not. The options are: - "O0" (FP32) - "O1" (Mixed Precision) - "O2" (Almost FP16) - "O3" (Pure FP16). - For more information, refer to: https://nvidia.github.io/apex/amp.html + :param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve + training speed and reduce GPU memory usage. + For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] + and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. :param optimizer_name: what optimizer to use (default: AdamW) :param num_warmup_steps: number of warmup steps :param optimizer_correct_bias: Whether to correct bias in optimizer @@ -687,7 +685,6 @@ class DensePassageRetriever(DenseRetriever): n_epochs=n_epochs, grad_acc_steps=grad_acc_steps, device=self.devices[0], # Only use first device while multi-gpu training is not implemented - use_amp=use_amp, ) # 6. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time @@ -1236,7 +1233,7 @@ class TableTextRetriever(DenseRetriever): weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, - use_amp: Optional[str] = None, + use_amp: bool = False, optimizer_name: str = "AdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/mm_retrieval", @@ -1273,12 +1270,10 @@ class TableTextRetriever(DenseRetriever): :param epsilon: Epsilon parameter of optimizer. :param weight_decay: Weight decay parameter of optimizer. :param grad_acc_steps: Number of steps to accumulate gradient over before back-propagation is done. - :param use_amp: Whether to use automatic mixed precision (AMP) or not. The options are: - "O0" (FP32) - "O1" (Mixed Precision) - "O2" (Almost FP16) - "O3" (Pure FP16). - For more information, refer to: https://nvidia.github.io/apex/amp.html + :param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve + training speed and reduce GPU memory usage. + For more information, see (Haystack Optimization)[https://haystack.deepset.ai/guides/optimization] + and (Automatic Mixed Precision Package - Torch.amp)[https://pytorch.org/docs/stable/amp.html]. :param optimizer_name: What optimizer to use (default: TransformersAdamW). :param num_warmup_steps: Number of warmup steps. :param optimizer_correct_bias: Whether to correct bias in optimizer. @@ -1327,7 +1322,6 @@ class TableTextRetriever(DenseRetriever): n_epochs=n_epochs, grad_acc_steps=grad_acc_steps, device=self.devices[0], # Only use first device while multi-gpu training is not implemented - use_amp=use_amp, ) # 6. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time diff --git a/haystack/schema.py b/haystack/schema.py index d8c38c10f..0efd7a43e 100644 --- a/haystack/schema.py +++ b/haystack/schema.py @@ -1269,7 +1269,7 @@ class EvaluationResult: Answer metrics are: - exact_match (Did the query exactly return any gold answer? -> 1.0 or 0.0) - f1 (How well does the best matching returned results overlap with any gold answer on token basis?) - - sas if a SAS model has bin provided during during pipeline.eval() (How semantically similar is the prediction to the gold answers?) + - sas if a SAS model has been provided during pipeline.eval() (How semantically similar is the prediction to the gold answers?) """ multilabel_ids = answers["multilabel_id"].unique() # simulate top k retriever diff --git a/test/nodes/test_reader.py b/test/nodes/test_reader.py index ca7af5f44..981599d72 100644 --- a/test/nodes/test_reader.py +++ b/test/nodes/test_reader.py @@ -12,6 +12,8 @@ from haystack.schema import Document, Answer, Label, MultiLabel, Span from haystack.nodes.reader.base import BaseReader from haystack.nodes import FARMReader, TransformersReader +from ..conftest import SAMPLES_PATH + # TODO Fix bug in test_no_answer_output when using # @pytest.fixture(params=["farm", "transformers"]) @@ -405,3 +407,31 @@ def test_no_answer_reader_skips_empty_documents(no_answer_reader): ) assert predictions["answers"][0][0].answer == "" # Return no_answer for 1st query as document is empty assert predictions["answers"][1][1].answer == "Carla" # answer given for 2nd query as usual + + +def test_reader_training(tmp_path): + max_seq_len = 16 + max_query_length = 8 + reader = FARMReader( + model_name_or_path="deepset/tinyroberta-squad2", + use_gpu=False, + num_processes=0, + max_seq_len=max_seq_len, + doc_stride=2, + max_query_length=max_query_length, + ) + + save_dir = f"{tmp_path}/test_dpr_training" + reader.train( + data_dir=str(SAMPLES_PATH / "squad"), + train_filename="tiny.json", + dev_filename="tiny.json", + test_filename="tiny.json", + n_epochs=1, + batch_size=1, + grad_acc_steps=1, + save_dir=save_dir, + evaluate_every=2, + max_seq_len=max_seq_len, + max_query_length=max_query_length, + )