mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-10-31 01:39:11 +00:00 
			
		
		
		
	Merge pull request #5166 from tink2123/fix_save_load
fix save load, update det pretrain
This commit is contained in:
		
						commit
						a11fbc0f7f
					
				| @ -78,11 +78,11 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中 | ||||
| cd PaddleOCR/ | ||||
| # 根据backbone的不同选择下载对应的预训练模型 | ||||
| # 下载MobileNetV3的预训练模型 | ||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams | ||||
| wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams | ||||
| # 或,下载ResNet18_vd的预训练模型 | ||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams | ||||
| wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams | ||||
| # 或,下载ResNet50_vd的预训练模型 | ||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams | ||||
| wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams | ||||
| ``` | ||||
| 
 | ||||
| <a name="2-----"></a> | ||||
|  | ||||
| @ -67,11 +67,11 @@ And the responding download link of backbone pretrain weights can be found in (h | ||||
| ```shell | ||||
| cd PaddleOCR/ | ||||
| # Download the pre-trained model of MobileNetV3 | ||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams | ||||
| wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams | ||||
| # or, download the pre-trained model of ResNet18_vd | ||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams | ||||
| wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams | ||||
| # or, download the pre-trained model of ResNet50_vd | ||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams | ||||
| wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams | ||||
| 
 | ||||
| ``` | ||||
| 
 | ||||
|  | ||||
| @ -111,13 +111,16 @@ def load_pretrained_params(model, path): | ||||
|     params = paddle.load(path + '.pdparams') | ||||
|     state_dict = model.state_dict() | ||||
|     new_state_dict = {} | ||||
|     for k1, k2 in zip(state_dict.keys(), params.keys()): | ||||
|         if list(state_dict[k1].shape) == list(params[k2].shape): | ||||
|             new_state_dict[k1] = params[k2] | ||||
|     for k1 in params.keys(): | ||||
|         if k1 not in state_dict.keys(): | ||||
|             logger.warning("The pretrained params {} not in model".format(k1)) | ||||
|         else: | ||||
|             if list(state_dict[k1].shape) == list(params[k1].shape): | ||||
|                 new_state_dict[k1] = params[k1] | ||||
|             else: | ||||
|                 logger.warning( | ||||
|                     "The shape of model params {} {} not matched with loaded params {} {} !". | ||||
|                 format(k1, state_dict[k1].shape, k2, params[k2].shape)) | ||||
|                     format(k1, state_dict[k1].shape, k1, params[k1].shape)) | ||||
|     model.set_state_dict(new_state_dict) | ||||
|     logger.info("load pretrain successful from {}".format(path)) | ||||
|     return model | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 andyjpaddle
						andyjpaddle