mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-10-31 01:39:11 +00:00 
			
		
		
		
	Merge pull request #5166 from tink2123/fix_save_load
fix save load, update det pretrain
This commit is contained in:
		
						commit
						a11fbc0f7f
					
				| @ -78,11 +78,11 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中 | |||||||
| cd PaddleOCR/ | cd PaddleOCR/ | ||||||
| # 根据backbone的不同选择下载对应的预训练模型 | # 根据backbone的不同选择下载对应的预训练模型 | ||||||
| # 下载MobileNetV3的预训练模型 | # 下载MobileNetV3的预训练模型 | ||||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams | wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams | ||||||
| # 或,下载ResNet18_vd的预训练模型 | # 或,下载ResNet18_vd的预训练模型 | ||||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams | wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams | ||||||
| # 或,下载ResNet50_vd的预训练模型 | # 或,下载ResNet50_vd的预训练模型 | ||||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams | wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| <a name="2-----"></a> | <a name="2-----"></a> | ||||||
|  | |||||||
| @ -67,11 +67,11 @@ And the responding download link of backbone pretrain weights can be found in (h | |||||||
| ```shell | ```shell | ||||||
| cd PaddleOCR/ | cd PaddleOCR/ | ||||||
| # Download the pre-trained model of MobileNetV3 | # Download the pre-trained model of MobileNetV3 | ||||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams | wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams | ||||||
| # or, download the pre-trained model of ResNet18_vd | # or, download the pre-trained model of ResNet18_vd | ||||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams | wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams | ||||||
| # or, download the pre-trained model of ResNet50_vd | # or, download the pre-trained model of ResNet50_vd | ||||||
| wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams | wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams | ||||||
| 
 | 
 | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -111,13 +111,16 @@ def load_pretrained_params(model, path): | |||||||
|     params = paddle.load(path + '.pdparams') |     params = paddle.load(path + '.pdparams') | ||||||
|     state_dict = model.state_dict() |     state_dict = model.state_dict() | ||||||
|     new_state_dict = {} |     new_state_dict = {} | ||||||
|     for k1, k2 in zip(state_dict.keys(), params.keys()): |     for k1 in params.keys(): | ||||||
|         if list(state_dict[k1].shape) == list(params[k2].shape): |         if k1 not in state_dict.keys(): | ||||||
|             new_state_dict[k1] = params[k2] |             logger.warning("The pretrained params {} not in model".format(k1)) | ||||||
|  |         else: | ||||||
|  |             if list(state_dict[k1].shape) == list(params[k1].shape): | ||||||
|  |                 new_state_dict[k1] = params[k1] | ||||||
|             else: |             else: | ||||||
|                 logger.warning( |                 logger.warning( | ||||||
|                     "The shape of model params {} {} not matched with loaded params {} {} !". |                     "The shape of model params {} {} not matched with loaded params {} {} !". | ||||||
|                 format(k1, state_dict[k1].shape, k2, params[k2].shape)) |                     format(k1, state_dict[k1].shape, k1, params[k1].shape)) | ||||||
|     model.set_state_dict(new_state_dict) |     model.set_state_dict(new_state_dict) | ||||||
|     logger.info("load pretrain successful from {}".format(path)) |     logger.info("load pretrain successful from {}".format(path)) | ||||||
|     return model |     return model | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 andyjpaddle
						andyjpaddle