mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-26 21:24:27 +00:00
parent
cf03afb10a
commit
b18b656633
@ -66,7 +66,6 @@ def amp_scaler(config):
|
||||
if "AMP" in config and config["AMP"]["use_amp"] is True:
|
||||
AMP_RELATED_FLAGS_SETTING = {
|
||||
"FLAGS_cudnn_batchnorm_spatial_persistent": 1,
|
||||
"FLAGS_max_inplace_grad_add": 8,
|
||||
}
|
||||
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
|
||||
scale_loss = config["AMP"].get("scale_loss", 1.0)
|
||||
|
@ -131,7 +131,6 @@ def main():
|
||||
if use_amp:
|
||||
AMP_RELATED_FLAGS_SETTING = {
|
||||
"FLAGS_cudnn_batchnorm_spatial_persistent": 1,
|
||||
"FLAGS_max_inplace_grad_add": 8,
|
||||
}
|
||||
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
|
||||
scale_loss = config["Global"].get("scale_loss", 1.0)
|
||||
|
@ -181,9 +181,7 @@ def main(config, device, logger, vdl_writer, seed):
|
||||
except:
|
||||
pass
|
||||
if use_amp:
|
||||
AMP_RELATED_FLAGS_SETTING = {
|
||||
"FLAGS_max_inplace_grad_add": 8,
|
||||
}
|
||||
AMP_RELATED_FLAGS_SETTING = {}
|
||||
if paddle.is_compiled_with_cuda():
|
||||
AMP_RELATED_FLAGS_SETTING.update(
|
||||
{
|
||||
|
Loading…
x
Reference in New Issue
Block a user