Mirror of https://github.com/PaddlePaddle/PaddleOCR.git
fix amp vqa
commit adeb8a17c9 (parent 91600fccb9)
@@ -255,6 +255,8 @@ def train(config,
                 with paddle.amp.auto_cast():
                     if model_type == 'table' or extra_input:
                         preds = model(images, data=batch[1:])
+                    elif model_type in ["kie", 'vqa']:
+                        preds = model(batch)
                     else:
                         preds = model(images)
             else:
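The two added lines are the actual fix: under AMP, 'kie' and 'vqa' models must receive the whole batch, exactly as they already do on the non-AMP path, since these models consume extra inputs beyond the image tensor (token ids, boxes, masks for the LayoutLM-style architectures). Before this change, the auto_cast branch fell through to model(images) and dropped those inputs. A minimal runnable sketch of the dispatch pattern, using a hypothetical DummyVQAModel in place of a real PaddleOCR architecture:

import paddle


class DummyVQAModel(paddle.nn.Layer):
    # Hypothetical stand-in: kie/vqa models take the whole batch
    # (image plus token ids, boxes, masks, ...), not images alone.
    def forward(self, batch):
        return batch[0].mean()  # placeholder computation


model = DummyVQAModel()
model_type = 'vqa'
batch = [paddle.rand([1, 3, 32, 32])]  # real batches carry more tensors
images = batch[0]

# The dispatch this hunk adds inside the AMP branch of train():
with paddle.amp.auto_cast():
    if model_type in ["kie", 'vqa']:
        preds = model(batch)   # whole batch, matching the non-AMP path
    else:
        preds = model(images)  # most det/rec models take images only
print(float(preds))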
@@ -307,7 +309,8 @@ def train(config,
             train_stats.update(stats)

             if log_writer is not None and dist.get_rank() == 0:
-                log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
+                log_writer.log_metrics(
+                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)

             if dist.get_rank() == 0 and (
                 (global_step > 0 and global_step % print_batch_step == 0) or
@@ -354,7 +357,8 @@ def train(config,

                 # logger metric
                 if log_writer is not None:
-                    log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
+                    log_writer.log_metrics(
+                        metrics=cur_metric, prefix="EVAL", step=global_step)

                 if cur_metric[main_indicator] >= best_model_dict[
                         main_indicator]:
@@ -377,11 +381,18 @@ def train(config,
                     logger.info(best_str)
                     # logger best metric
                     if log_writer is not None:
-                        log_writer.log_metrics(metrics={
-                            "best_{}".format(main_indicator): best_model_dict[main_indicator]
-                        }, prefix="EVAL", step=global_step)
-
-                        log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
+                        log_writer.log_metrics(
+                            metrics={
+                                "best_{}".format(main_indicator):
+                                best_model_dict[main_indicator]
+                            },
+                            prefix="EVAL",
+                            step=global_step)
+
+                        log_writer.log_model(
+                            is_best=True,
+                            prefix="best_accuracy",
+                            metadata=best_model_dict)

             reader_start = time.time()
         if dist.get_rank() == 0:
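Beyond the line rewrapping, these hunks document the small logger surface the training loop relies on: log_metrics(metrics, prefix, step) and log_model(is_best, prefix, metadata). A toy sketch of a compatible writer, assuming only the keyword signatures visible in this diff (PaddleOCR's real VDLLogger and WandbLogger implement the same surface); the metric name and values below are illustrative:

class PrintLogger:
    # Toy stand-in honoring the call signatures seen in this diff.
    def log_metrics(self, metrics, prefix=None, step=None):
        print("[{}] step {}: {}".format(prefix, step, metrics))

    def log_model(self, is_best=False, prefix="", metadata=None):
        print("saved '{}' (best={}) metadata={}".format(
            prefix, is_best, metadata))


log_writer = PrintLogger()
log_writer.log_metrics(
    metrics={"best_hmean": 0.62}, prefix="EVAL", step=1000)
log_writer.log_model(
    is_best=True, prefix="best_accuracy", metadata={"best_epoch": 7})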
@@ -413,7 +424,8 @@ def train(config,
                     epoch=epoch,
                     global_step=global_step)
             if log_writer is not None:
-                log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
+                log_writer.log_model(
+                    is_best=False, prefix='iter_epoch_{}'.format(epoch))

     best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
@@ -585,7 +597,8 @@ def preprocess(is_train=False):
         vdl_writer_path = '{}/vdl/'.format(save_model_dir)
         log_writer = VDLLogger(save_model_dir)
         loggers.append(log_writer)
-    if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
+    if ('use_wandb' in config['Global'] and
+            config['Global']['use_wandb']) or 'wandb' in config:
         save_dir = config['Global']['save_model_dir']
         wandb_writer_path = "{}/wandb".format(save_dir)
         if "wandb" in config:
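The rewrapped condition in this last hunk encodes the enabling rule for the wandb logger: either Global.use_wandb is truthy, or a top-level wandb section exists in the config. A standalone sketch of the same gate on a hand-built config dict (keys mirror the diff; values are illustrative):

config = {
    'Global': {'save_model_dir': './output', 'use_wandb': False},
    'wandb': {'project': 'PaddleOCR'},  # presence alone enables logging
}

if ('use_wandb' in config['Global'] and
        config['Global']['use_wandb']) or 'wandb' in config:
    save_dir = config['Global']['save_model_dir']
    wandb_writer_path = "{}/wandb".format(save_dir)
    print("wandb enabled, writing under", wandb_writer_path)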