mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-10-03 12:06:36 +00:00
support eval pre epoch (#11003)
This commit is contained in:
parent
e49e491417
commit
4ba32bc91c
@ -185,6 +185,7 @@ def train(config,
|
|||||||
eval_class,
|
eval_class,
|
||||||
pre_best_model_dict,
|
pre_best_model_dict,
|
||||||
logger,
|
logger,
|
||||||
|
step_pre_epoch,
|
||||||
log_writer=None,
|
log_writer=None,
|
||||||
scaler=None,
|
scaler=None,
|
||||||
amp_level='O2',
|
amp_level='O2',
|
||||||
@ -198,6 +199,7 @@ def train(config,
|
|||||||
epoch_num = config['Global']['epoch_num']
|
epoch_num = config['Global']['epoch_num']
|
||||||
print_batch_step = config['Global']['print_batch_step']
|
print_batch_step = config['Global']['print_batch_step']
|
||||||
eval_batch_step = config['Global']['eval_batch_step']
|
eval_batch_step = config['Global']['eval_batch_step']
|
||||||
|
eval_batch_epoch = config['Global'].get('eval_batch_epoch', None)
|
||||||
profiler_options = config['profiler_options']
|
profiler_options = config['profiler_options']
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
@ -205,8 +207,9 @@ def train(config,
|
|||||||
global_step = pre_best_model_dict['global_step']
|
global_step = pre_best_model_dict['global_step']
|
||||||
start_eval_step = 0
|
start_eval_step = 0
|
||||||
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
||||||
start_eval_step = eval_batch_step[0]
|
start_eval_step = eval_batch_step[0] if not eval_batch_epoch else 0
|
||||||
eval_batch_step = eval_batch_step[1]
|
eval_batch_step = eval_batch_step[
|
||||||
|
1] if not eval_batch_epoch else step_pre_epoch * eval_batch_epoch
|
||||||
if len(valid_dataloader) == 0:
|
if len(valid_dataloader) == 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
'No Images in eval dataset, evaluation during training ' \
|
'No Images in eval dataset, evaluation during training ' \
|
||||||
|
@ -61,9 +61,11 @@ def main(config, device, logger, vdl_writer, seed):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if config['Eval']:
|
if config['Eval']:
|
||||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger, seed)
|
valid_dataloader = build_dataloader(config, 'Eval', device, logger,
|
||||||
|
seed)
|
||||||
else:
|
else:
|
||||||
valid_dataloader = None
|
valid_dataloader = None
|
||||||
|
step_pre_epoch = len(train_dataloader)
|
||||||
|
|
||||||
# build post process
|
# build post process
|
||||||
post_process_class = build_post_process(config['PostProcess'],
|
post_process_class = build_post_process(config['PostProcess'],
|
||||||
@ -93,7 +95,8 @@ def main(config, device, logger, vdl_writer, seed):
|
|||||||
'DistillationSARLoss'][
|
'DistillationSARLoss'][
|
||||||
'ignore_index'] = char_num + 1
|
'ignore_index'] = char_num + 1
|
||||||
out_channels_list['SARLabelDecode'] = char_num + 2
|
out_channels_list['SARLabelDecode'] = char_num + 2
|
||||||
elif any('DistillationNRTRLoss' in d for d in config['Loss']['loss_config_list']):
|
elif any('DistillationNRTRLoss' in d
|
||||||
|
for d in config['Loss']['loss_config_list']):
|
||||||
out_channels_list['NRTRLabelDecode'] = char_num + 3
|
out_channels_list['NRTRLabelDecode'] = char_num + 3
|
||||||
|
|
||||||
config['Architecture']['Models'][key]['Head'][
|
config['Architecture']['Models'][key]['Head'][
|
||||||
@ -196,9 +199,9 @@ def main(config, device, logger, vdl_writer, seed):
|
|||||||
# start train
|
# start train
|
||||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||||
loss_class, optimizer, lr_scheduler, post_process_class,
|
loss_class, optimizer, lr_scheduler, post_process_class,
|
||||||
eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
|
eval_class, pre_best_model_dict, logger, step_pre_epoch,
|
||||||
amp_level, amp_custom_black_list, amp_custom_white_list,
|
vdl_writer, scaler, amp_level, amp_custom_black_list,
|
||||||
amp_dtype)
|
amp_custom_white_list, amp_dtype)
|
||||||
|
|
||||||
|
|
||||||
def test_reader(config, device, logger):
|
def test_reader(config, device, logger):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user