| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #    http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from __future__ import absolute_import | 
					
						
							|  |  |  | from __future__ import division | 
					
						
							|  |  |  | from __future__ import print_function | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from paddle import nn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Im2Seq(nn.Layer): | 
					
						
							|  |  |  |     def __init__(self, in_channels, **kwargs): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.out_channels = in_channels | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x): | 
					
						
							|  |  |  |         B, C, H, W = x.shape | 
					
						
							| 
									
										
										
										
											2020-11-10 17:18:50 +08:00
										 |  |  |         assert H == 1 | 
					
						
							|  |  |  |         x = x.squeeze(axis=2) | 
					
						
							|  |  |  |         x = x.transpose([0, 2, 1])  # (NTC)(batch, width, channels) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         return x | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class EncoderWithRNN(nn.Layer): | 
					
						
							|  |  |  |     def __init__(self, in_channels, hidden_size): | 
					
						
							|  |  |  |         super(EncoderWithRNN, self).__init__() | 
					
						
							|  |  |  |         self.out_channels = hidden_size * 2 | 
					
						
							|  |  |  |         self.lstm = nn.LSTM( | 
					
						
							|  |  |  |             in_channels, hidden_size, direction='bidirectional', num_layers=2) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x): | 
					
						
							|  |  |  |         x, _ = self.lstm(x) | 
					
						
							|  |  |  |         return x | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class EncoderWithFC(nn.Layer): | 
					
						
							|  |  |  |     def __init__(self, in_channels, hidden_size): | 
					
						
							|  |  |  |         super(EncoderWithFC, self).__init__() | 
					
						
							|  |  |  |         self.out_channels = hidden_size | 
					
						
							|  |  |  |         weight_attr, bias_attr = get_para_bias_attr( | 
					
						
							|  |  |  |             l2_decay=0.00001, k=in_channels, name='reduce_encoder_fea') | 
					
						
							|  |  |  |         self.fc = nn.Linear( | 
					
						
							|  |  |  |             in_channels, | 
					
						
							|  |  |  |             hidden_size, | 
					
						
							|  |  |  |             weight_attr=weight_attr, | 
					
						
							|  |  |  |             bias_attr=bias_attr, | 
					
						
							|  |  |  |             name='reduce_encoder_fea') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x): | 
					
						
							|  |  |  |         x = self.fc(x) | 
					
						
							|  |  |  |         return x | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SequenceEncoder(nn.Layer): | 
					
						
							| 
									
										
										
										
											2020-10-16 16:39:37 +08:00
										 |  |  |     def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs): | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         super(SequenceEncoder, self).__init__() | 
					
						
							| 
									
										
										
										
											2020-10-20 16:07:19 +08:00
										 |  |  |         self.encoder_reshape = Im2Seq(in_channels) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         self.out_channels = self.encoder_reshape.out_channels | 
					
						
							|  |  |  |         if encoder_type == 'reshape': | 
					
						
							|  |  |  |             self.only_reshape = True | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             support_encoder_dict = { | 
					
						
							| 
									
										
										
										
											2020-10-20 16:07:19 +08:00
										 |  |  |                 'reshape': Im2Seq, | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |                 'fc': EncoderWithFC, | 
					
						
							|  |  |  |                 'rnn': EncoderWithRNN | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2020-11-10 17:18:50 +08:00
										 |  |  |             assert encoder_type in support_encoder_dict, '{} must in {}'.format( | 
					
						
							|  |  |  |                 encoder_type, support_encoder_dict.keys()) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             self.encoder = support_encoder_dict[encoder_type]( | 
					
						
							|  |  |  |                 self.encoder_reshape.out_channels, hidden_size) | 
					
						
							|  |  |  |             self.out_channels = self.encoder.out_channels | 
					
						
							|  |  |  |             self.only_reshape = False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x): | 
					
						
							|  |  |  |         x = self.encoder_reshape(x) | 
					
						
							|  |  |  |         if not self.only_reshape: | 
					
						
							|  |  |  |             x = self.encoder(x) | 
					
						
							|  |  |  |         return x |