mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2026-01-07 12:37:13 +00:00
fix bug in amp eval
This commit is contained in:
parent
9e4ae9dc12
commit
aa7a0cb2a8
@ -8,7 +8,7 @@ Global:
|
||||
# evaluation is run every 835 iterations
|
||||
eval_batch_step: [0, 4000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
|
||||
pretrained_model: pretrain_models/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy.pdparams
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
|
||||
@ -13,7 +13,7 @@ train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
|
||||
null:null
|
||||
##
|
||||
trainer:norm_train
|
||||
norm_train:tools/train.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
|
||||
norm_train:tools/train.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
|
||||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
@ -27,7 +27,7 @@ null:null
|
||||
===========================infer_params===========================
|
||||
Global.save_inference_dir:./output/
|
||||
Architecture.Backbone.checkpoints:
|
||||
norm_export:tools/export_model.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o
|
||||
norm_export:tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o
|
||||
quant_export:
|
||||
fpgm_export:
|
||||
distill_export:null
|
||||
|
||||
@ -154,12 +154,13 @@ def check_xpu(use_xpu):
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def to_float32(preds):
|
||||
if isinstance(preds, dict):
|
||||
for k in preds:
|
||||
if isinstance(preds[k], dict) or isinstance(preds[k], list):
|
||||
preds[k] = to_float32(preds[k])
|
||||
else:
|
||||
elif isinstance(preds[k], paddle.Tensor):
|
||||
preds[k] = preds[k].astype(paddle.float32)
|
||||
elif isinstance(preds, list):
|
||||
for k in range(len(preds)):
|
||||
@ -167,12 +168,13 @@ def to_float32(preds):
|
||||
preds[k] = to_float32(preds[k])
|
||||
elif isinstance(preds[k], list):
|
||||
preds[k] = to_float32(preds[k])
|
||||
else:
|
||||
elif isinstance(preds[k], paddle.Tensor):
|
||||
preds[k] = preds[k].astype(paddle.float32)
|
||||
else:
|
||||
elif isinstance(preds[k], paddle.Tensor):
|
||||
preds = preds.astype(paddle.float32)
|
||||
return preds
|
||||
|
||||
|
||||
def train(config,
|
||||
train_dataloader,
|
||||
valid_dataloader,
|
||||
@ -370,7 +372,8 @@ def train(config,
|
||||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
extra_input=extra_input)
|
||||
extra_input=extra_input,
|
||||
scaler=scaler)
|
||||
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||
logger.info(cur_metric_str)
|
||||
@ -460,7 +463,8 @@ def eval(model,
|
||||
post_process_class,
|
||||
eval_class,
|
||||
model_type=None,
|
||||
extra_input=False):
|
||||
extra_input=False,
|
||||
scaler=None):
|
||||
model.eval()
|
||||
with paddle.no_grad():
|
||||
total_frame = 0.0
|
||||
@ -477,12 +481,24 @@ def eval(model,
|
||||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
elif model_type in ["kie", 'vqa']:
|
||||
preds = model(batch)
|
||||
|
||||
# use amp
|
||||
if scaler:
|
||||
with paddle.amp.auto_cast(level='O2'):
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
elif model_type in ["kie", 'vqa']:
|
||||
preds = model(batch)
|
||||
else:
|
||||
preds = model(images)
|
||||
else:
|
||||
preds = model(images)
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
elif model_type in ["kie", 'vqa']:
|
||||
preds = model(batch)
|
||||
else:
|
||||
preds = model(images)
|
||||
|
||||
batch_numpy = []
|
||||
for item in batch:
|
||||
if isinstance(item, paddle.Tensor):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user