# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F


class SAREncoder(nn.Layer):
    """
    Encoder of SAR: pools the backbone feature map along the height axis and
    encodes the resulting width-wise sequence with a two-layer RNN to produce
    a holistic feature.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # LSTM/GRU Encoder
        if enc_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'
        kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction=direction)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**kwargs)

        # global feature transformation
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        if img_metas is not None:
            assert len(img_metas[0]) == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        h_feat = feat.shape[2]  # bsz c h w
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            valid_hf = []
            T = paddle.shape(holistic_feat)[1]
            for i in range(paddle.shape(valid_ratios)[0]):
                valid_step = paddle.minimum(
                    T, paddle.ceil(valid_ratios[i] * T).astype('int32')) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = paddle.stack(valid_hf, axis=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C
        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat
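

# A minimal usage sketch of SAREncoder (illustration only, not called anywhere
# in the library): `_sar_encoder_demo` and the tensor sizes below are
# hypothetical and simply assume a backbone feature map of shape
# [bsz, d_model, H, W] with the default hyper-parameters above.
def _sar_encoder_demo():
    encoder = SAREncoder(d_model=512, d_enc=512)
    encoder.eval()
    feat = paddle.rand([2, 512, 6, 40])  # bsz * C * H * W
    holistic_feat = encoder(feat)  # bsz * C, taken at the last RNN step
    return holistic_feat.shape  # expected: [2, 512]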


class BaseDecoder(nn.Layer):
    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, img_metas)
        return self.forward_test(feat, out_enc, img_metas)


class ParallelSARDecoder(BaseDecoder):
    """
    Parallel decoder of SAR: a 2D attention map over the backbone feature map
    is computed for every decoding step, and the attended (glimpse) feature is
    fed to the prediction layer.

    Args:
        out_channels (int): Output class number.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout probability of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        start_idx (int): Index of start token, set to out_channels - 2.
        padding_idx (int): Index of padding token, set to out_channels - 1.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'

        kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):

        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = paddle.shape(attn_weight)
        assert c == 1

        if valid_ratios is not None:
            # cal mask of attention weight
            for i in range(paddle.shape(valid_ratios)[0]):
                valid_width = paddle.minimum(
                    w, paddle.ceil(valid_ratios[i] * w).astype("int32"))
                if valid_width < w:
                    attn_weight[i, :, :, valid_width:, :] = float('-inf')

        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(
                paddle.concat((y, attn_feat.astype(y.dtype),
                               holistic_feat.astype(y.dtype)), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        img_metas: [label, valid_ratio]
        '''
        if img_metas is not None:
            assert paddle.shape(img_metas[0])[0] == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1).astype(lab_embedding.dtype)
        # bsz * 1 * emb_dim
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs
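

# A minimal inference sketch of ParallelSARDecoder (illustration only, not
# called anywhere in the library): `_parallel_sar_decoder_demo`, the class
# count of 93 (90 characters + unknown + start + padding) and the tensor
# sizes are hypothetical assumptions.
def _parallel_sar_decoder_demo():
    decoder = ParallelSARDecoder(out_channels=93, d_model=512, d_enc=512)
    decoder.eval()
    feat = paddle.rand([2, 512, 6, 40])  # backbone feature, bsz * C * H * W
    out_enc = paddle.rand([2, 512])  # holistic feature from SAREncoder
    probs = decoder(
        feat, out_enc, label=None, img_metas=None, train_mode=False)
    # expected: [2, 30, 92], i.e. (bsz, max_seq_len, out_channels - 1)
    return probs.shape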
					
						
class SARHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 enc_dim=512,
                 max_text_length=30,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        self.encoder = SAREncoder(
            enc_bi_rnn=enc_bi_rnn,
            enc_drop_rnn=enc_drop_rnn,
            enc_gru=enc_gru,
            d_model=in_channels,
            d_enc=enc_dim)

        # decoder module
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_model=in_channels,
            d_enc=enc_dim,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)

    def forward(self, feat, targets=None):
        '''
        targets: [label, valid_ratio]
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if self.training:
            label = targets[0]  # label
            final_out = self.decoder(
                feat, holistic_feat, label, img_metas=targets)
        else:
            final_out = self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)
            # (bsz, seq_len, num_classes)

        return final_out
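

# A minimal end-to-end inference sketch of SARHead (illustration only, not
# called anywhere in the library): `_sar_head_demo`, the channel/class counts
# and the feature-map size are hypothetical assumptions; in_channels must
# match the backbone output channels.
def _sar_head_demo():
    head = SARHead(in_channels=512, out_channels=93)
    head.eval()
    feat = paddle.rand([1, 512, 6, 40])  # bsz * C * H * W
    with paddle.no_grad():
        preds = head(feat, targets=None)
    # expected: [1, 30, 92], i.e. (bsz, max_text_length, out_channels - 1)
    return preds.shape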