| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #     http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | from __future__ import absolute_import | 
					
						
							|  |  |  | from __future__ import division | 
					
						
							|  |  |  | from __future__ import print_function | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from paddle import nn | 
					
						
							|  |  |  | from ppocr.modeling.transforms import build_transform | 
					
						
							|  |  |  | from ppocr.modeling.backbones import build_backbone | 
					
						
							|  |  |  | from ppocr.modeling.necks import build_neck | 
					
						
							|  |  |  | from ppocr.modeling.heads import build_head | 
					
						
							|  |  |  | from .base_model import BaseModel | 
					
						
							| 
									
										
										
										
											2021-11-12 11:06:36 +08:00
										 |  |  | from ppocr.utils.save_load import load_pretrained_params | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | __all__ = ['DistillationModel'] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class DistillationModel(nn.Layer): | 
					
						
							|  |  |  |     def __init__(self, config): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         the module for OCR distillation. | 
					
						
							|  |  |  |         args: | 
					
						
							|  |  |  |             config (dict): the super parameters for module. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							| 
									
										
										
										
											2021-06-03 06:53:24 +00:00
										 |  |  |         self.model_list = [] | 
					
						
							|  |  |  |         self.model_name_list = [] | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  |         for key in config["Models"]: | 
					
						
							|  |  |  |             model_config = config["Models"][key] | 
					
						
							| 
									
										
										
										
											2021-06-03 05:30:43 +00:00
										 |  |  |             freeze_params = False | 
					
						
							|  |  |  |             pretrained = None | 
					
						
							|  |  |  |             if "freeze_params" in model_config: | 
					
						
							|  |  |  |                 freeze_params = model_config.pop("freeze_params") | 
					
						
							|  |  |  |             if "pretrained" in model_config: | 
					
						
							|  |  |  |                 pretrained = model_config.pop("pretrained") | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  |             model = BaseModel(model_config) | 
					
						
							| 
									
										
										
										
											2021-06-03 05:30:43 +00:00
										 |  |  |             if pretrained is not None: | 
					
						
							| 
									
										
										
										
											2021-07-27 02:57:53 +00:00
										 |  |  |                 load_pretrained_params(model, pretrained) | 
					
						
							| 
									
										
										
										
											2021-06-03 05:30:43 +00:00
										 |  |  |             if freeze_params: | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  |                 for param in model.parameters(): | 
					
						
							|  |  |  |                     param.trainable = False | 
					
						
							| 
									
										
										
										
											2021-06-03 06:53:24 +00:00
										 |  |  |             self.model_list.append(self.add_sublayer(key, model)) | 
					
						
							|  |  |  |             self.model_name_list.append(key) | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-26 16:19:31 +08:00
										 |  |  |     def forward(self, x, data=None): | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  |         result_dict = dict() | 
					
						
							| 
									
										
										
										
											2021-06-03 06:53:24 +00:00
										 |  |  |         for idx, model_name in enumerate(self.model_name_list): | 
					
						
							| 
									
										
										
										
											2022-04-26 16:19:31 +08:00
										 |  |  |             result_dict[model_name] = self.model_list[idx](x, data) | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  |         return result_dict |