| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. | 
					
						
							|  |  |  | # | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | # | 
					
						
							|  |  |  | #    http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | from __future__ import absolute_import | 
					
						
							|  |  |  | from __future__ import division | 
					
						
							|  |  |  | from __future__ import print_function | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import errno | 
					
						
							|  |  |  | import os | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | import pickle | 
					
						
							|  |  |  | import six | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | import paddle | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-05 06:52:45 +00:00
										 |  |  | from ppocr.utils.logging import get_logger | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-12 11:06:36 +08:00
# Names exported by a star-import of this module. Only load_model is listed;
# the other public helpers (load_pretrained_params, save_model) must be
# imported explicitly by name.
__all__ = ['load_model']
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | def _mkdir_if_not_exist(path, logger): | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     mkdir if not exists, ignore the exception when multiprocess mkdir together | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     if not os.path.exists(path): | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             os.makedirs(path) | 
					
						
							|  |  |  |         except OSError as e: | 
					
						
							|  |  |  |             if e.errno == errno.EEXIST and os.path.isdir(path): | 
					
						
							|  |  |  |                 logger.warning( | 
					
						
							|  |  |  |                     'be happy if some process has already created {}'.format( | 
					
						
							|  |  |  |                         path)) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 raise OSError('Failed to mkdir {}'.format(path)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-12 11:06:36 +08:00
										 |  |  | def load_model(config, model, optimizer=None): | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     load model from checkpoint or pretrained_model | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2021-06-05 06:52:45 +00:00
										 |  |  |     logger = get_logger() | 
					
						
							| 
									
										
										
										
											2021-05-23 19:16:01 +08:00
										 |  |  |     global_config = config['Global'] | 
					
						
							|  |  |  |     checkpoints = global_config.get('checkpoints') | 
					
						
							|  |  |  |     pretrained_model = global_config.get('pretrained_model') | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     best_model_dict = {} | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  |     if checkpoints: | 
					
						
							| 
									
										
										
										
											2021-11-24 10:23:06 +00:00
										 |  |  |         if checkpoints.endswith('.pdparams'): | 
					
						
							| 
									
										
										
										
											2021-11-12 11:06:36 +08:00
										 |  |  |             checkpoints = checkpoints.replace('.pdparams', '') | 
					
						
							| 
									
										
										
										
											2021-11-24 09:28:34 +00:00
										 |  |  |         assert os.path.exists(checkpoints + ".pdparams"), \ | 
					
						
							| 
									
										
										
										
											2021-11-24 09:40:33 +00:00
										 |  |  |             "The {}.pdparams does not exists!".format(checkpoints) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-24 09:28:34 +00:00
										 |  |  |         # load params from trained model | 
					
						
							|  |  |  |         params = paddle.load(checkpoints + '.pdparams') | 
					
						
							|  |  |  |         state_dict = model.state_dict() | 
					
						
							|  |  |  |         new_state_dict = {} | 
					
						
							|  |  |  |         for key, value in state_dict.items(): | 
					
						
							|  |  |  |             if key not in params: | 
					
						
							| 
									
										
										
										
											2021-11-24 09:40:33 +00:00
										 |  |  |                 logger.warning("{} not in loaded params {} !".format( | 
					
						
							|  |  |  |                     key, params.keys())) | 
					
						
							| 
									
										
										
										
											2021-11-24 09:28:34 +00:00
										 |  |  |             pre_value = params[key] | 
					
						
							|  |  |  |             if list(value.shape) == list(pre_value.shape): | 
					
						
							|  |  |  |                 new_state_dict[key] = pre_value | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 logger.warning( | 
					
						
							| 
									
										
										
										
											2021-11-24 09:40:33 +00:00
										 |  |  |                     "The shape of model params {} {} not matched with loaded params shape {} !". | 
					
						
							|  |  |  |                     format(key, value.shape, pre_value.shape)) | 
					
						
							| 
									
										
										
										
											2021-11-24 09:28:34 +00:00
										 |  |  |         model.set_state_dict(new_state_dict) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-12 11:06:36 +08:00
										 |  |  |         optim_dict = paddle.load(checkpoints + '.pdopt') | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         if optimizer is not None: | 
					
						
							| 
									
										
										
										
											2021-11-12 11:06:36 +08:00
										 |  |  |             optimizer.set_state_dict(optim_dict) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if os.path.exists(checkpoints + '.states'): | 
					
						
							|  |  |  |             with open(checkpoints + '.states', 'rb') as f: | 
					
						
							|  |  |  |                 states_dict = pickle.load(f) if six.PY2 else pickle.load( | 
					
						
							|  |  |  |                     f, encoding='latin1') | 
					
						
							|  |  |  |             best_model_dict = states_dict.get('best_model_dict', {}) | 
					
						
							|  |  |  |             if 'epoch' in states_dict: | 
					
						
							|  |  |  |                 best_model_dict['start_epoch'] = states_dict['epoch'] + 1 | 
					
						
							|  |  |  |         logger.info("resume from {}".format(checkpoints)) | 
					
						
							|  |  |  |     elif pretrained_model: | 
					
						
							| 
									
										
										
										
											2021-11-12 11:06:36 +08:00
										 |  |  |         load_pretrained_params(model, pretrained_model) | 
					
						
							| 
									
										
										
										
											2020-08-12 07:08:07 +00:00
										 |  |  |     else: | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         logger.info('train from scratch') | 
					
						
							|  |  |  |     return best_model_dict | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-07 01:54:03 +00:00
										 |  |  | def load_pretrained_params(model, path): | 
					
						
							| 
									
										
										
										
											2021-11-12 11:06:36 +08:00
										 |  |  |     logger = get_logger() | 
					
						
							| 
									
										
										
										
											2021-11-24 10:23:06 +00:00
										 |  |  |     if path.endswith('.pdparams'): | 
					
						
							| 
									
										
										
										
											2021-11-12 11:06:36 +08:00
										 |  |  |         path = path.replace('.pdparams', '') | 
					
						
							|  |  |  |     assert os.path.exists(path + ".pdparams"), \ | 
					
						
							| 
									
										
										
										
											2021-11-24 09:40:33 +00:00
										 |  |  |         "The {}.pdparams does not exists!".format(path) | 
					
						
							| 
									
										
										
										
											2021-11-12 11:06:36 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     params = paddle.load(path + '.pdparams') | 
					
						
							| 
									
										
										
										
											2021-07-07 01:54:03 +00:00
										 |  |  |     state_dict = model.state_dict() | 
					
						
							|  |  |  |     new_state_dict = {} | 
					
						
							|  |  |  |     for k1, k2 in zip(state_dict.keys(), params.keys()): | 
					
						
							|  |  |  |         if list(state_dict[k1].shape) == list(params[k2].shape): | 
					
						
							|  |  |  |             new_state_dict[k1] = params[k2] | 
					
						
							| 
									
										
										
										
											2021-07-07 07:54:02 +00:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2021-11-24 09:28:34 +00:00
										 |  |  |             logger.warning( | 
					
						
							| 
									
										
										
										
											2021-11-24 09:40:33 +00:00
										 |  |  |                 "The shape of model params {} {} not matched with loaded params {} {} !". | 
					
						
							|  |  |  |                 format(k1, state_dict[k1].shape, k2, params[k2].shape)) | 
					
						
							| 
									
										
										
										
											2021-07-07 01:54:03 +00:00
										 |  |  |     model.set_state_dict(new_state_dict) | 
					
						
							| 
									
										
										
										
											2021-11-24 09:40:33 +00:00
										 |  |  |     logger.info("load pretrain successful from {}".format(path)) | 
					
						
							| 
									
										
										
										
											2021-07-08 14:32:44 +00:00
										 |  |  |     return model | 
					
						
							| 
									
										
										
										
											2021-06-28 20:44:06 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-09 23:51:48 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-26 21:13:21 -05:00
										 |  |  | def save_model(model, | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |                optimizer, | 
					
						
							|  |  |  |                model_path, | 
					
						
							|  |  |  |                logger, | 
					
						
							|  |  |  |                is_best=False, | 
					
						
							|  |  |  |                prefix='ppocr', | 
					
						
							|  |  |  |                **kwargs): | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     save model to the target path | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     _mkdir_if_not_exist(model_path, logger) | 
					
						
							|  |  |  |     model_prefix = os.path.join(model_path, prefix) | 
					
						
							| 
									
										
										
										
											2021-04-26 21:13:21 -05:00
										 |  |  |     paddle.save(model.state_dict(), model_prefix + '.pdparams') | 
					
						
							| 
									
										
										
										
											2020-11-09 13:27:31 +08:00
										 |  |  |     paddle.save(optimizer.state_dict(), model_prefix + '.pdopt') | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # save metric and config | 
					
						
							|  |  |  |     with open(model_prefix + '.states', 'wb') as f: | 
					
						
							|  |  |  |         pickle.dump(kwargs, f, protocol=2) | 
					
						
							|  |  |  |     if is_best: | 
					
						
							|  |  |  |         logger.info('save best model is to {}'.format(model_prefix)) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         logger.info("save model in {}".format(model_prefix)) |