mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-11-07 05:13:29 +00:00
add db++
This commit is contained in:
parent
1315cdfc86
commit
04e7104194
@ -18,7 +18,7 @@ Global:
|
|||||||
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
||||||
Architecture:
|
Architecture:
|
||||||
model_type: det
|
model_type: det
|
||||||
algorithm: DB
|
algorithm: DB++
|
||||||
Transform: null
|
Transform: null
|
||||||
Backbone:
|
Backbone:
|
||||||
name: ResNet
|
name: ResNet
|
||||||
|
|||||||
@ -18,7 +18,7 @@ Global:
|
|||||||
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
||||||
Architecture:
|
Architecture:
|
||||||
model_type: det
|
model_type: det
|
||||||
algorithm: DB
|
algorithm: DB++
|
||||||
Transform: null
|
Transform: null
|
||||||
Backbone:
|
Backbone:
|
||||||
name: ResNet
|
name: ResNet
|
||||||
|
|||||||
@ -67,6 +67,23 @@ class TextDetector(object):
|
|||||||
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
||||||
postprocess_params["use_dilation"] = args.use_dilation
|
postprocess_params["use_dilation"] = args.use_dilation
|
||||||
postprocess_params["score_mode"] = args.det_db_score_mode
|
postprocess_params["score_mode"] = args.det_db_score_mode
|
||||||
|
elif self.det_algorithm == "DB++":
|
||||||
|
postprocess_params['name'] = 'DBPostProcess'
|
||||||
|
postprocess_params["thresh"] = args.det_db_thresh
|
||||||
|
postprocess_params["box_thresh"] = args.det_db_box_thresh
|
||||||
|
postprocess_params["max_candidates"] = 1000
|
||||||
|
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
||||||
|
postprocess_params["use_dilation"] = args.use_dilation
|
||||||
|
postprocess_params["score_mode"] = args.det_db_score_mode
|
||||||
|
pre_process_list[1] = {
|
||||||
|
'NormalizeImage': {
|
||||||
|
'std': [1.0, 1.0, 1.0],
|
||||||
|
'mean':
|
||||||
|
[0.48109378172549, 0.45752457890196, 0.40787054090196],
|
||||||
|
'scale': '1./255.',
|
||||||
|
'order': 'hwc'
|
||||||
|
}
|
||||||
|
}
|
||||||
elif self.det_algorithm == "EAST":
|
elif self.det_algorithm == "EAST":
|
||||||
postprocess_params['name'] = 'EASTPostProcess'
|
postprocess_params['name'] = 'EASTPostProcess'
|
||||||
postprocess_params["score_thresh"] = args.det_east_score_thresh
|
postprocess_params["score_thresh"] = args.det_east_score_thresh
|
||||||
@ -231,7 +248,7 @@ class TextDetector(object):
|
|||||||
preds['f_score'] = outputs[1]
|
preds['f_score'] = outputs[1]
|
||||||
preds['f_tco'] = outputs[2]
|
preds['f_tco'] = outputs[2]
|
||||||
preds['f_tvo'] = outputs[3]
|
preds['f_tvo'] = outputs[3]
|
||||||
elif self.det_algorithm in ['DB', 'PSE']:
|
elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
|
||||||
preds['maps'] = outputs[0]
|
preds['maps'] = outputs[0]
|
||||||
elif self.det_algorithm == 'FCE':
|
elif self.det_algorithm == 'FCE':
|
||||||
for i, output in enumerate(outputs):
|
for i, output in enumerate(outputs):
|
||||||
|
|||||||
@ -307,7 +307,8 @@ def train(config,
|
|||||||
train_stats.update(stats)
|
train_stats.update(stats)
|
||||||
|
|
||||||
if log_writer is not None and dist.get_rank() == 0:
|
if log_writer is not None and dist.get_rank() == 0:
|
||||||
log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
|
log_writer.log_metrics(
|
||||||
|
metrics=train_stats.get(), prefix="TRAIN", step=global_step)
|
||||||
|
|
||||||
if dist.get_rank() == 0 and (
|
if dist.get_rank() == 0 and (
|
||||||
(global_step > 0 and global_step % print_batch_step == 0) or
|
(global_step > 0 and global_step % print_batch_step == 0) or
|
||||||
@ -354,7 +355,8 @@ def train(config,
|
|||||||
|
|
||||||
# logger metric
|
# logger metric
|
||||||
if log_writer is not None:
|
if log_writer is not None:
|
||||||
log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
|
log_writer.log_metrics(
|
||||||
|
metrics=cur_metric, prefix="EVAL", step=global_step)
|
||||||
|
|
||||||
if cur_metric[main_indicator] >= best_model_dict[
|
if cur_metric[main_indicator] >= best_model_dict[
|
||||||
main_indicator]:
|
main_indicator]:
|
||||||
@ -377,11 +379,18 @@ def train(config,
|
|||||||
logger.info(best_str)
|
logger.info(best_str)
|
||||||
# logger best metric
|
# logger best metric
|
||||||
if log_writer is not None:
|
if log_writer is not None:
|
||||||
log_writer.log_metrics(metrics={
|
log_writer.log_metrics(
|
||||||
"best_{}".format(main_indicator): best_model_dict[main_indicator]
|
metrics={
|
||||||
}, prefix="EVAL", step=global_step)
|
"best_{}".format(main_indicator):
|
||||||
|
best_model_dict[main_indicator]
|
||||||
log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
|
},
|
||||||
|
prefix="EVAL",
|
||||||
|
step=global_step)
|
||||||
|
|
||||||
|
log_writer.log_model(
|
||||||
|
is_best=True,
|
||||||
|
prefix="best_accuracy",
|
||||||
|
metadata=best_model_dict)
|
||||||
|
|
||||||
reader_start = time.time()
|
reader_start = time.time()
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
@ -413,7 +422,8 @@ def train(config,
|
|||||||
epoch=epoch,
|
epoch=epoch,
|
||||||
global_step=global_step)
|
global_step=global_step)
|
||||||
if log_writer is not None:
|
if log_writer is not None:
|
||||||
log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
|
log_writer.log_model(
|
||||||
|
is_best=False, prefix='iter_epoch_{}'.format(epoch))
|
||||||
|
|
||||||
best_str = 'best metric, {}'.format(', '.join(
|
best_str = 'best metric, {}'.format(', '.join(
|
||||||
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
|
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
|
||||||
@ -564,7 +574,7 @@ def preprocess(is_train=False):
|
|||||||
assert alg in [
|
assert alg in [
|
||||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
|
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', 'DB++'
|
||||||
]
|
]
|
||||||
|
|
||||||
if use_xpu:
|
if use_xpu:
|
||||||
@ -585,7 +595,8 @@ def preprocess(is_train=False):
|
|||||||
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
||||||
log_writer = VDLLogger(save_model_dir)
|
log_writer = VDLLogger(save_model_dir)
|
||||||
loggers.append(log_writer)
|
loggers.append(log_writer)
|
||||||
if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
|
if ('use_wandb' in config['Global'] and
|
||||||
|
config['Global']['use_wandb']) or 'wandb' in config:
|
||||||
save_dir = config['Global']['save_model_dir']
|
save_dir = config['Global']['save_model_dir']
|
||||||
wandb_writer_path = "{}/wandb".format(save_dir)
|
wandb_writer_path = "{}/wandb".format(save_dir)
|
||||||
if "wandb" in config:
|
if "wandb" in config:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user