Mirror of https://github.com/PaddlePaddle/PaddleOCR.git, synced 2025-12-28 07:28:55 +00:00
add Const lr
This commit is contained in:
parent 07633eb850
commit 8a28962cd7
@@ -34,10 +34,12 @@ Optimizer:
   beta2: 0.999
   clip_norm: 10
   lr:
-    name: Piecewise
-    values: [0.000005, 0.00005]
-    decay_epochs: [10]
-    warmup_epoch: 0
+    # name: Piecewise
+    # values: [0.000005, 0.00005]
+    # decay_epochs: [10]
+    # warmup_epoch: 0
+    learning_rate: 0.00005
+    warmup_epoch: 10
   regularizer:
     name: L2
     factor: 0.00000
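The first hunk keeps the old Piecewise schedule around as comments: it started at 5e-06 and stepped up to 5e-05 at epoch 10, whereas the replacement pins the rate at 5e-05 with a 10-epoch warmup. A hedged sketch of the schedule those commented lines disable, built on Paddle's PiecewiseDecay (step_each_epoch=100 is an assumed figure; the repo's Piecewise wrapper derives the real step boundaries from decay_epochs and the loader size):

from paddle.optimizer import lr

# The disabled schedule: 5e-06 until the boundary at epoch 10, then
# 5e-05 afterwards. step_each_epoch is assumed purely for illustration.
step_each_epoch = 100
piecewise = lr.PiecewiseDecay(
    boundaries=[10 * step_each_epoch],
    values=[0.000005, 0.00005])
print(piecewise.get_lr())  # 5e-06 while still before step 1000

The second hunk makes the same replacement in another config file, this time dropping the old schedule outright: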
@@ -34,10 +34,8 @@ Optimizer:
   beta2: 0.999
   clip_norm: 10
   lr:
-    name: Piecewise
-    values: [0.000005, 0.00005]
-    decay_epochs: [10]
-    warmup_epoch: 0
+    learning_rate: 0.00005
+    warmup_epoch: 10
   regularizer:
     name: L2
     factor: 0.00000
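After either change, the lr block no longer carries a name key. A small sketch (assuming PyYAML-style loading, which is how these configs are read) of what the edited block parses to, i.e. the dict that reaches the builder in the next hunk:

import yaml

# The edited lr block has no 'name' key; that absence is what the new
# default in build_lr_scheduler (next hunk) keys off.
block = yaml.safe_load("""
lr:
  learning_rate: 0.00005
  warmup_epoch: 10
""")
print(block['lr'])  # {'learning_rate': 5e-05, 'warmup_epoch': 10}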
@@ -25,11 +25,8 @@ __all__ = ['build_optimizer']
 def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     from . import learning_rate
     lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
-    if 'name' in lr_config:
-        lr_name = lr_config.pop('name')
-        lr = getattr(learning_rate, lr_name)(**lr_config)()
-    else:
-        lr = lr_config['learning_rate']
+    lr_name = lr_config.pop('name', 'Const')
+    lr = getattr(learning_rate, lr_name)(**lr_config)()
     return lr


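The builder change collapses the old if/else: rather than returning the raw float when name is absent, pop() now defaults to 'Const', so the bare-float case flows through the same getattr dispatch as every named scheduler. A standalone sketch of that dispatch (the SimpleNamespace stand-in for the learning_rate module is hypothetical; the real module is the one patched in the next hunk):

import types

# Hypothetical stand-in for the learning_rate module: Const here just
# returns a callable yielding the configured float, ignoring warmup.
learning_rate = types.SimpleNamespace(
    Const=lambda **cfg: (lambda: cfg['learning_rate']))

def build_lr_scheduler(lr_config, epochs, step_each_epoch):
    lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
    lr_name = lr_config.pop('name', 'Const')  # missing 'name' -> 'Const'
    return getattr(learning_rate, lr_name)(**lr_config)()

print(build_lr_scheduler({'learning_rate': 0.00005, 'warmup_epoch': 10},
                         epochs=500, step_each_epoch=100))  # 5e-05

One behavioral note: the fallback only fires when the name key is missing entirely; a present-but-misspelled value still raises an AttributeError through getattr, as before.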
@@ -275,4 +275,36 @@ class OneCycle(object):
                 start_lr=0.0,
                 end_lr=self.max_lr,
                 last_epoch=self.last_epoch)
-        return learning_rate
+        return learning_rate
+
+
+class Const(object):
+    """
+    Const learning rate decay
+    Args:
+        learning_rate(float): initial learning rate
+        step_each_epoch(int): steps each epoch
+        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+    """
+
+    def __init__(self,
+                 learning_rate,
+                 step_each_epoch,
+                 warmup_epoch=0,
+                 last_epoch=-1,
+                 **kwargs):
+        super(Const, self).__init__()
+        self.learning_rate = learning_rate
+        self.last_epoch = last_epoch
+        self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+    def __call__(self):
+        learning_rate = self.learning_rate
+        if self.warmup_epoch > 0:
+            learning_rate = lr.LinearWarmup(
+                learning_rate=learning_rate,
+                warmup_steps=self.warmup_epoch,
+                start_lr=0.0,
+                end_lr=self.learning_rate,
+                last_epoch=self.last_epoch)
+        return learning_rate
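A small usage sketch of the new class, assuming paddle is importable and Const is defined as in the hunk above (the lr.LinearWarmup call implies the module already does from paddle.optimizer import lr):

from paddle.optimizer import lr  # implied by the lr.LinearWarmup call above

# warmup_epoch=0 (the default): __call__ returns the bare float, so the
# optimizer runs at a truly constant rate.
print(Const(learning_rate=0.00005, step_each_epoch=100)())  # 5e-05

# warmup_epoch=10 with step_each_epoch=100: __call__ wraps the float in
# a LinearWarmup scheduler that ramps 0.0 -> 5e-05 over
# round(10 * 100) = 1000 steps, then holds 5e-05 for the rest of training.
scheduler = Const(
    learning_rate=0.00005, step_each_epoch=100, warmup_epoch=10)()
scheduler.step()
print(scheduler.get_lr())  # still near 0.0 this early in the ramp

Note the round(warmup_epoch * step_each_epoch) conversion in __init__: warmup is configured in epochs, but LinearWarmup counts optimizer steps, the same convention the other wrappers in this file follow.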