mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-09-05 06:13:06 +00:00
merge init_model and load_dygraph_params to load_model (#4623)
* merge init_model and load_dygraph_params to load_model
This commit is contained in:
parent
1417a3c2cf
commit
ae4167dc32
@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model
|
|||||||
|
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.metrics import build_metric
|
from ppocr.metrics import build_metric
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import load_model
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
|
|
||||||
@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer):
|
|||||||
logger.info(f"FLOPs after pruning: {flops}")
|
logger.info(f"FLOPs after pruning: {flops}")
|
||||||
|
|
||||||
# load pretrain model
|
# load pretrain model
|
||||||
pre_best_model_dict = init_model(config, model, logger, None)
|
load_model(config, model)
|
||||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||||
eval_class)
|
eval_class)
|
||||||
logger.info(f"metric['hmean']: {metric['hmean']}")
|
logger.info(f"metric['hmean']: {metric['hmean']}")
|
||||||
|
@ -32,7 +32,7 @@ from ppocr.losses import build_loss
|
|||||||
from ppocr.optimizer import build_optimizer
|
from ppocr.optimizer import build_optimizer
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.metrics import build_metric
|
from ppocr.metrics import build_metric
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import load_model
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
dist.get_world_size()
|
dist.get_world_size()
|
||||||
@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer):
|
|||||||
# build metric
|
# build metric
|
||||||
eval_class = build_metric(config['Metric'])
|
eval_class = build_metric(config['Metric'])
|
||||||
# load pretrain model
|
# load pretrain model
|
||||||
pre_best_model_dict = init_model(config, model, logger, optimizer)
|
pre_best_model_dict = load_model(config, model, optimizer)
|
||||||
|
|
||||||
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
||||||
format(len(train_dataloader), len(valid_dataloader)))
|
format(len(train_dataloader), len(valid_dataloader)))
|
||||||
|
@ -28,7 +28,7 @@ from paddle.jit import to_static
|
|||||||
|
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import load_model
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
from tools.program import load_config, merge_config, ArgsParser
|
from tools.program import load_config, merge_config, ArgsParser
|
||||||
from ppocr.metrics import build_metric
|
from ppocr.metrics import build_metric
|
||||||
@ -101,7 +101,7 @@ def main():
|
|||||||
quanter = QAT(config=quant_config)
|
quanter = QAT(config=quant_config)
|
||||||
quanter.quantize(model)
|
quanter.quantize(model)
|
||||||
|
|
||||||
init_model(config, model)
|
load_model(config, model)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# build metric
|
# build metric
|
||||||
|
@ -37,7 +37,7 @@ from ppocr.losses import build_loss
|
|||||||
from ppocr.optimizer import build_optimizer
|
from ppocr.optimizer import build_optimizer
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.metrics import build_metric
|
from ppocr.metrics import build_metric
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import load_model
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
from paddleslim.dygraph.quant import QAT
|
from paddleslim.dygraph.quant import QAT
|
||||||
|
|
||||||
@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer):
|
|||||||
# build metric
|
# build metric
|
||||||
eval_class = build_metric(config['Metric'])
|
eval_class = build_metric(config['Metric'])
|
||||||
# load pretrain model
|
# load pretrain model
|
||||||
pre_best_model_dict = init_model(config, model, logger, optimizer)
|
pre_best_model_dict = load_model(config, model, optimizer)
|
||||||
|
|
||||||
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
||||||
format(len(train_dataloader), len(valid_dataloader)))
|
format(len(train_dataloader), len(valid_dataloader)))
|
||||||
|
@ -37,7 +37,7 @@ from ppocr.losses import build_loss
|
|||||||
from ppocr.optimizer import build_optimizer
|
from ppocr.optimizer import build_optimizer
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.metrics import build_metric
|
from ppocr.metrics import build_metric
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import load_model
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
import paddleslim
|
import paddleslim
|
||||||
from paddleslim.dygraph.quant import QAT
|
from paddleslim.dygraph.quant import QAT
|
||||||
|
@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
|
|||||||
from ppocr.modeling.necks import build_neck
|
from ppocr.modeling.necks import build_neck
|
||||||
from ppocr.modeling.heads import build_head
|
from ppocr.modeling.heads import build_head
|
||||||
from .base_model import BaseModel
|
from .base_model import BaseModel
|
||||||
from ppocr.utils.save_load import init_model, load_pretrained_params
|
from ppocr.utils.save_load import load_pretrained_params
|
||||||
|
|
||||||
__all__ = ['DistillationModel']
|
__all__ = ['DistillationModel']
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ import paddle
|
|||||||
|
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
|
|
||||||
__all__ = ['init_model', 'save_model', 'load_dygraph_params']
|
__all__ = ['load_model']
|
||||||
|
|
||||||
|
|
||||||
def _mkdir_if_not_exist(path, logger):
|
def _mkdir_if_not_exist(path, logger):
|
||||||
@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
|
|||||||
raise OSError('Failed to mkdir {}'.format(path))
|
raise OSError('Failed to mkdir {}'.format(path))
|
||||||
|
|
||||||
|
|
||||||
def init_model(config, model, optimizer=None, lr_scheduler=None):
|
def load_model(config, model, optimizer=None):
|
||||||
"""
|
"""
|
||||||
load model from checkpoint or pretrained_model
|
load model from checkpoint or pretrained_model
|
||||||
"""
|
"""
|
||||||
@ -54,15 +54,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
|
|||||||
pretrained_model = global_config.get('pretrained_model')
|
pretrained_model = global_config.get('pretrained_model')
|
||||||
best_model_dict = {}
|
best_model_dict = {}
|
||||||
if checkpoints:
|
if checkpoints:
|
||||||
assert os.path.exists(checkpoints + ".pdparams"), \
|
if checkpoints.endswith('pdparams'):
|
||||||
"Given dir {}.pdparams not exist.".format(checkpoints)
|
checkpoints = checkpoints.replace('.pdparams', '')
|
||||||
assert os.path.exists(checkpoints + ".pdopt"), \
|
assert os.path.exists(checkpoints + ".pdopt"), \
|
||||||
"Given dir {}.pdopt not exist.".format(checkpoints)
|
f"The {checkpoints}.pdopt does not exists!"
|
||||||
para_dict = paddle.load(checkpoints + '.pdparams')
|
load_pretrained_params(model, checkpoints)
|
||||||
opti_dict = paddle.load(checkpoints + '.pdopt')
|
optim_dict = paddle.load(checkpoints + '.pdopt')
|
||||||
model.set_state_dict(para_dict)
|
|
||||||
if optimizer is not None:
|
if optimizer is not None:
|
||||||
optimizer.set_state_dict(opti_dict)
|
optimizer.set_state_dict(optim_dict)
|
||||||
|
|
||||||
if os.path.exists(checkpoints + '.states'):
|
if os.path.exists(checkpoints + '.states'):
|
||||||
with open(checkpoints + '.states', 'rb') as f:
|
with open(checkpoints + '.states', 'rb') as f:
|
||||||
@ -73,70 +72,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
|
|||||||
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
|
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
|
||||||
logger.info("resume from {}".format(checkpoints))
|
logger.info("resume from {}".format(checkpoints))
|
||||||
elif pretrained_model:
|
elif pretrained_model:
|
||||||
if not isinstance(pretrained_model, list):
|
load_pretrained_params(model, pretrained_model)
|
||||||
pretrained_model = [pretrained_model]
|
|
||||||
for pretrained in pretrained_model:
|
|
||||||
if not (os.path.isdir(pretrained) or
|
|
||||||
os.path.exists(pretrained + '.pdparams')):
|
|
||||||
raise ValueError("Model pretrain path {} does not "
|
|
||||||
"exists.".format(pretrained))
|
|
||||||
param_state_dict = paddle.load(pretrained + '.pdparams')
|
|
||||||
model.set_state_dict(param_state_dict)
|
|
||||||
logger.info("load pretrained model from {}".format(
|
|
||||||
pretrained_model))
|
|
||||||
else:
|
else:
|
||||||
logger.info('train from scratch')
|
logger.info('train from scratch')
|
||||||
return best_model_dict
|
return best_model_dict
|
||||||
|
|
||||||
|
|
||||||
def load_dygraph_params(config, model, logger, optimizer):
|
|
||||||
ckp = config['Global']['checkpoints']
|
|
||||||
if ckp and os.path.exists(ckp + ".pdparams"):
|
|
||||||
pre_best_model_dict = init_model(config, model, optimizer)
|
|
||||||
return pre_best_model_dict
|
|
||||||
else:
|
|
||||||
pm = config['Global']['pretrained_model']
|
|
||||||
if pm is None:
|
|
||||||
return {}
|
|
||||||
if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
|
|
||||||
logger.info(f"The pretrained_model {pm} does not exists!")
|
|
||||||
return {}
|
|
||||||
pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
|
|
||||||
params = paddle.load(pm)
|
|
||||||
state_dict = model.state_dict()
|
|
||||||
new_state_dict = {}
|
|
||||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
|
||||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
|
||||||
new_state_dict[k1] = params[k2]
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
|
||||||
)
|
|
||||||
model.set_state_dict(new_state_dict)
|
|
||||||
logger.info(f"loaded pretrained_model successful from {pm}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def load_pretrained_params(model, path):
|
def load_pretrained_params(model, path):
|
||||||
if path is None:
|
logger = get_logger()
|
||||||
return False
|
if path.endswith('pdparams'):
|
||||||
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
|
path = path.replace('.pdparams', '')
|
||||||
print(f"The pretrained_model {path} does not exists!")
|
assert os.path.exists(path + ".pdparams"), \
|
||||||
return False
|
f"The {path}.pdparams does not exists!"
|
||||||
|
|
||||||
path = path if path.endswith('.pdparams') else path + '.pdparams'
|
params = paddle.load(path + '.pdparams')
|
||||||
params = paddle.load(path)
|
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||||
new_state_dict[k1] = params[k2]
|
new_state_dict[k1] = params[k2]
|
||||||
else:
|
else:
|
||||||
print(
|
logger.info(
|
||||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||||
)
|
)
|
||||||
model.set_state_dict(new_state_dict)
|
model.set_state_dict(new_state_dict)
|
||||||
print(f"load pretrain successful from {path}")
|
logger.info(f"load pretrain successful from {path}")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
|
|||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.metrics import build_metric
|
from ppocr.metrics import build_metric
|
||||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
from ppocr.utils.save_load import load_model
|
||||||
from ppocr.utils.utility import print_dict
|
from ppocr.utils.utility import print_dict
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
model_type = None
|
model_type = None
|
||||||
|
|
||||||
best_model_dict = load_dygraph_params(config, model, logger, None)
|
best_model_dict = load_model(config, model)
|
||||||
if len(best_model_dict):
|
if len(best_model_dict):
|
||||||
logger.info('metric in ckpt ***************')
|
logger.info('metric in ckpt ***************')
|
||||||
for k, v in best_model_dict.items():
|
for k, v in best_model_dict.items():
|
||||||
|
@ -27,7 +27,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
|||||||
from ppocr.data import build_dataloader
|
from ppocr.data import build_dataloader
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
from ppocr.utils.save_load import load_model
|
||||||
from ppocr.utils.utility import print_dict
|
from ppocr.utils.utility import print_dict
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
@ -57,7 +57,7 @@ def main():
|
|||||||
|
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
|
|
||||||
best_model_dict = load_dygraph_params(config, model, logger, None)
|
best_model_dict = load_model(config, model)
|
||||||
if len(best_model_dict):
|
if len(best_model_dict):
|
||||||
logger.info('metric in ckpt ***************')
|
logger.info('metric in ckpt ***************')
|
||||||
for k, v in best_model_dict.items():
|
for k, v in best_model_dict.items():
|
||||||
|
@ -26,7 +26,7 @@ from paddle.jit import to_static
|
|||||||
|
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import load_model
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
from tools.program import load_config, merge_config, ArgsParser
|
from tools.program import load_config, merge_config, ArgsParser
|
||||||
|
|
||||||
@ -107,7 +107,7 @@ def main():
|
|||||||
else: # base rec model
|
else: # base rec model
|
||||||
config["Architecture"]["Head"]["out_channels"] = char_num
|
config["Architecture"]["Head"]["out_channels"] = char_num
|
||||||
model = build_model(config["Architecture"])
|
model = build_model(config["Architecture"])
|
||||||
init_model(config, model)
|
load_model(config, model)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
save_path = config["Global"]["save_inference_dir"]
|
save_path = config["Global"]["save_inference_dir"]
|
||||||
|
@ -32,7 +32,7 @@ import paddle
|
|||||||
from ppocr.data import create_operators, transform
|
from ppocr.data import create_operators, transform
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import load_model
|
||||||
from ppocr.utils.utility import get_image_file_list
|
from ppocr.utils.utility import get_image_file_list
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ def main():
|
|||||||
# build model
|
# build model
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
|
|
||||||
init_model(config, model)
|
load_model(config, model)
|
||||||
|
|
||||||
# create data ops
|
# create data ops
|
||||||
transforms = []
|
transforms = []
|
||||||
|
@ -34,7 +34,7 @@ import paddle
|
|||||||
from ppocr.data import create_operators, transform
|
from ppocr.data import create_operators, transform
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
from ppocr.utils.save_load import load_model
|
||||||
from ppocr.utils.utility import get_image_file_list
|
from ppocr.utils.utility import get_image_file_list
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
@ -59,7 +59,7 @@ def main():
|
|||||||
# build model
|
# build model
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
|
|
||||||
_ = load_dygraph_params(config, model, logger, None)
|
load_model(config, model)
|
||||||
# build post process
|
# build post process
|
||||||
post_process_class = build_post_process(config['PostProcess'])
|
post_process_class = build_post_process(config['PostProcess'])
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ import paddle
|
|||||||
from ppocr.data import create_operators, transform
|
from ppocr.data import create_operators, transform
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import load_model
|
||||||
from ppocr.utils.utility import get_image_file_list
|
from ppocr.utils.utility import get_image_file_list
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ def main():
|
|||||||
# build model
|
# build model
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
|
|
||||||
init_model(config, model)
|
load_model(config, model)
|
||||||
|
|
||||||
# build post process
|
# build post process
|
||||||
post_process_class = build_post_process(config['PostProcess'],
|
post_process_class = build_post_process(config['PostProcess'],
|
||||||
|
@ -33,7 +33,7 @@ import paddle
|
|||||||
from ppocr.data import create_operators, transform
|
from ppocr.data import create_operators, transform
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import load_model
|
||||||
from ppocr.utils.utility import get_image_file_list
|
from ppocr.utils.utility import get_image_file_list
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ def main():
|
|||||||
|
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
|
|
||||||
init_model(config, model)
|
load_model(config, model)
|
||||||
|
|
||||||
# create data ops
|
# create data ops
|
||||||
transforms = []
|
transforms = []
|
||||||
@ -75,9 +75,7 @@ def main():
|
|||||||
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
|
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
|
||||||
]
|
]
|
||||||
elif config['Architecture']['algorithm'] == "SAR":
|
elif config['Architecture']['algorithm'] == "SAR":
|
||||||
op[op_name]['keep_keys'] = [
|
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
|
||||||
'image', 'valid_ratio'
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
op[op_name]['keep_keys'] = ['image']
|
op[op_name]['keep_keys'] = ['image']
|
||||||
transforms.append(op)
|
transforms.append(op)
|
||||||
|
@ -34,11 +34,12 @@ from paddle.jit import to_static
|
|||||||
from ppocr.data import create_operators, transform
|
from ppocr.data import create_operators, transform
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import load_model
|
||||||
from ppocr.utils.utility import get_image_file_list
|
from ppocr.utils.utility import get_image_file_list
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
def main(config, device, logger, vdl_writer):
|
def main(config, device, logger, vdl_writer):
|
||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
|
|
||||||
@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer):
|
|||||||
|
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
|
|
||||||
init_model(config, model, logger)
|
load_model(config, model)
|
||||||
|
|
||||||
# create data ops
|
# create data ops
|
||||||
transforms = []
|
transforms = []
|
||||||
@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
config, device, logger, vdl_writer = program.preprocess()
|
config, device, logger, vdl_writer = program.preprocess()
|
||||||
main(config, device, logger, vdl_writer)
|
main(config, device, logger, vdl_writer)
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ from ppocr.losses import build_loss
|
|||||||
from ppocr.optimizer import build_optimizer
|
from ppocr.optimizer import build_optimizer
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.metrics import build_metric
|
from ppocr.metrics import build_metric
|
||||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
from ppocr.utils.save_load import load_model
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
dist.get_world_size()
|
dist.get_world_size()
|
||||||
@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
|
|||||||
# build metric
|
# build metric
|
||||||
eval_class = build_metric(config['Metric'])
|
eval_class = build_metric(config['Metric'])
|
||||||
# load pretrain model
|
# load pretrain model
|
||||||
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
|
pre_best_model_dict = load_model(config, model, optimizer)
|
||||||
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
|
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
|
||||||
if valid_dataloader is not None:
|
if valid_dataloader is not None:
|
||||||
logger.info('valid dataloader has {} iters'.format(
|
logger.info('valid dataloader has {} iters'.format(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user