| 
									
										
										
										
											2021-06-21 12:20:25 +00:00
										 |  |  | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #     http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | from __future__ import absolute_import | 
					
						
							|  |  |  | from __future__ import division | 
					
						
							|  |  |  | from __future__ import print_function | 
					
						
							| 
									
										
										
										
											2022-08-12 10:49:54 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | from paddle import nn | 
					
						
							| 
									
										
										
										
											2020-12-09 20:26:40 +08:00
										 |  |  | from ppocr.modeling.transforms import build_transform | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | from ppocr.modeling.backbones import build_backbone | 
					
						
							|  |  |  | from ppocr.modeling.necks import build_neck | 
					
						
							|  |  |  | from ppocr.modeling.heads import build_head | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
# Public API of this module: only the model class is exported by `import *`.
__all__ = ['BaseModel']
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-10 17:18:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  | class BaseModel(nn.Layer): | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     def __init__(self, config): | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |         the module for OCR. | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         args: | 
					
						
							|  |  |  |             config (dict): the super parameters for module. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |         super(BaseModel, self).__init__() | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         in_channels = config.get('in_channels', 3) | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |         model_type = config['model_type'] | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         # build transfrom, | 
					
						
							|  |  |  |         # for rec, transfrom can be TPS,None | 
					
						
							|  |  |  |         # for det and cls, transfrom shoule to be None, | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |         # if you make model differently, you can use transfrom in det and cls | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         if 'Transform' not in config or config['Transform'] is None: | 
					
						
							|  |  |  |             self.use_transform = False | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.use_transform = True | 
					
						
							|  |  |  |             config['Transform']['in_channels'] = in_channels | 
					
						
							|  |  |  |             self.transform = build_transform(config['Transform']) | 
					
						
							|  |  |  |             in_channels = self.transform.out_channels | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # build backbone, backbone is need for del, rec and cls | 
					
						
							| 
									
										
										
										
											2022-08-12 10:49:54 +08:00
										 |  |  |         if 'Backbone' not in config or config['Backbone'] is None: | 
					
						
							|  |  |  |             self.use_backbone = False | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.use_backbone = True | 
					
						
							|  |  |  |             config["Backbone"]['in_channels'] = in_channels | 
					
						
							|  |  |  |             self.backbone = build_backbone(config["Backbone"], model_type) | 
					
						
							|  |  |  |             in_channels = self.backbone.out_channels | 
					
						
							| 
									
										
										
										
											2020-11-10 17:18:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         # build neck | 
					
						
							|  |  |  |         # for rec, neck can be cnn,rnn or reshape(None) | 
					
						
							|  |  |  |         # for det, neck can be FPN, BIFPN and so on. | 
					
						
							|  |  |  |         # for cls, neck should be none | 
					
						
							|  |  |  |         if 'Neck' not in config or config['Neck'] is None: | 
					
						
							|  |  |  |             self.use_neck = False | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.use_neck = True | 
					
						
							|  |  |  |             config['Neck']['in_channels'] = in_channels | 
					
						
							|  |  |  |             self.neck = build_neck(config['Neck']) | 
					
						
							|  |  |  |             in_channels = self.neck.out_channels | 
					
						
							| 
									
										
										
										
											2020-11-10 17:18:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-20 17:39:07 +08:00
										 |  |  |         # # build head, head is need for det, rec and cls | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |         if 'Head' not in config or config['Head'] is None: | 
					
						
							|  |  |  |             self.use_head = False | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.use_head = True | 
					
						
							|  |  |  |             config["Head"]['in_channels'] = in_channels | 
					
						
							|  |  |  |             self.head = build_head(config["Head"]) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-03 05:57:31 +00:00
										 |  |  |         self.return_all_feats = config.get("return_all_feats", False) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  |     def forward(self, x, data=None): | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-03 05:57:31 +00:00
										 |  |  |         y = dict() | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         if self.use_transform: | 
					
						
							|  |  |  |             x = self.transform(x) | 
					
						
							| 
									
										
										
										
											2022-08-12 10:49:54 +08:00
										 |  |  |         if self.use_backbone: | 
					
						
							|  |  |  |             x = self.backbone(x) | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |         if isinstance(x, dict): | 
					
						
							|  |  |  |             y.update(x) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             y["backbone_out"] = x | 
					
						
							|  |  |  |         final_name = "backbone_out" | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         if self.use_neck: | 
					
						
							|  |  |  |             x = self.neck(x) | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |             if isinstance(x, dict): | 
					
						
							|  |  |  |                 y.update(x) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 y["neck_out"] = x | 
					
						
							|  |  |  |             final_name = "neck_out" | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |         if self.use_head: | 
					
						
							|  |  |  |             x = self.head(x, targets=data) | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |             # for multi head, save ctc neck out for udml | 
					
						
							|  |  |  |             if isinstance(x, dict) and 'ctc_neck' in x.keys(): | 
					
						
							|  |  |  |                 y["neck_out"] = x["ctc_neck"] | 
					
						
							|  |  |  |                 y["head_out"] = x | 
					
						
							|  |  |  |             elif isinstance(x, dict): | 
					
						
							|  |  |  |                 y.update(x) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 y["head_out"] = x | 
					
						
							|  |  |  |             final_name = "head_out" | 
					
						
							| 
									
										
										
										
											2021-06-03 05:57:31 +00:00
										 |  |  |         if self.return_all_feats: | 
					
						
							| 
									
										
										
										
											2022-04-28 18:04:05 +08:00
										 |  |  |             if self.training: | 
					
						
							|  |  |  |                 return y | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |             elif isinstance(x, dict): | 
					
						
							|  |  |  |                 return x | 
					
						
							| 
									
										
										
										
											2022-04-28 18:04:05 +08:00
										 |  |  |             else: | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |                 return {final_name: x} | 
					
						
							| 
									
										
										
										
											2021-06-03 05:57:31 +00:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2022-08-12 10:49:54 +08:00
										 |  |  |             return x |