mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-11-04 03:39:22 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			411 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			411 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#    http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
"""
 | 
						|
This code is adapted from:
 | 
						|
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
 | 
						|
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
 | 
						|
"""
 | 
						|
 | 
						|
from __future__ import absolute_import
 | 
						|
from __future__ import division
 | 
						|
from __future__ import print_function
 | 
						|
 | 
						|
import math
 | 
						|
import paddle
 | 
						|
from paddle import ParamAttr
 | 
						|
import paddle.nn as nn
 | 
						|
import paddle.nn.functional as F
 | 
						|
 | 
						|
 | 
						|
class SAREncoder(nn.Layer):
    """Holistic feature encoder of SAR.

    The backbone feature map is max-pooled over its full height, the
    resulting column sequence is run through a 2-layer RNN, and the RNN
    output at the last valid time step is projected by a linear layer into
    one holistic feature vector per image.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # Encoder RNN: LSTM by default, GRU when enc_gru is set. The config
        # gets its own name instead of rebinding the ``**kwargs`` parameter.
        direction = 'bidirectional' if enc_bi_rnn else 'forward'
        rnn_kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction=direction)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**rnn_kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**rnn_kwargs)

        # global feature transformation; a bidirectional RNN doubles the
        # output width.
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        """Encode a feature map into one holistic feature vector per image.

        Args:
            feat (Tensor): backbone feature map laid out as (bsz, C, H, W).
            img_metas (list | None): when given, ``img_metas[0]`` must have
                one entry per image, and ``img_metas[-1]`` is used as the
                per-image valid width ratios (only when ``self.mask``).

        Returns:
            Tensor: holistic feature of shape (bsz, encoder_rnn_out_size).
        """
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        # Collapse the height axis with a full-height max pool, then treat
        # the width axis as the RNN time axis.
        h_feat = feat.shape[2]  # bsz c h w
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            T = holistic_feat.shape[1]
            valid_hf = []
            for i, valid_ratio in enumerate(valid_ratios):
                # Index of the last time step that still covers real
                # (non-padded) content. Clamp at 0 so a zero ratio selects
                # the first step instead of wrapping around to index -1,
                # which would be the last, fully padded step.
                valid_step = max(0, min(T, math.ceil(T * valid_ratio)) - 1)
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = paddle.stack(valid_hf, axis=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C
        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat
 | 
						|
 | 
						|
 | 
						|
class BaseDecoder(nn.Layer):
    """Base decoder that dispatches on a training/inference flag.

    Subclasses implement ``forward_train`` and ``forward_test``. ``forward``
    stores the requested mode on ``self.train_mode`` (so subclasses can
    read it later) and delegates to the matching method.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        """Training-mode decoding; subclasses must override this."""
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        """Inference-mode decoding; subclasses must override this."""
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        """Record ``train_mode`` and run the matching decoding path.

        Args:
            feat: backbone feature map.
            out_enc: encoder output.
            label: target indices, only passed to ``forward_train``.
            img_metas: extra per-image metadata, passed to both paths.
            train_mode (bool): selects ``forward_train`` when truthy,
                ``forward_test`` otherwise.
        """
        self.train_mode = train_mode

        if not train_mode:
            return self.forward_test(feat, out_enc, img_metas)
        return self.forward_train(feat, out_enc, label, img_metas)
 | 
						|
 | 
						|
 | 
						|
class ParallelSARDecoder(BaseDecoder):
    """Parallel 2D-attention decoder of SAR.

    The holistic encoder feature is prepended to the embedded input
    sequence and run through a 2-layer RNN. The RNN outputs query a 2D
    attention over the backbone feature map, and a linear layer predicts
    class scores for every step.

    Args:
        out_channels (int): Output class number, including the start and
            padding tokens. The last two indices are reserved:
            ``start_idx = out_channels - 2`` and
            ``padding_idx = out_channels - 1``.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.

    The prediction layer outputs ``out_channels - 1`` scores, so the
    padding index is never predicted.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        # The last two class indices are reserved for start / padding.
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer: the query comes from the decoder RNN output,
        # the key from a 3x3 conv over the feature map, and the combined
        # result is scored down to a single channel.
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer. The config gets its own name instead of
        # rebinding the ``**kwargs`` parameter.
        direction = 'bidirectional' if dec_bi_rnn else 'forward'
        rnn_kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**rnn_kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**rnn_kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer; the padding index is excluded from the outputs.
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _get_valid_ratios(self, feat, img_metas):
        """Return the per-image valid width ratios, or None if unmasked.

        Checks that ``img_metas[0]`` has one entry per image in ``feat``.
        The ratios are taken from ``img_metas[-1]``, and only when masking
        is enabled.
        """
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        if img_metas is not None and self.mask:
            return img_metas[-1]
        return None

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        """Run the decoder RNN and 2D attention, then predict class scores.

        Args:
            decoder_input (Tensor): RNN input of shape
                (bsz, seq_len + 1, emb_dim).
            feat (Tensor): backbone feature map of shape (bsz, C, h, w).
            holistic_feat (Tensor): encoder output; the callers in this
                class pass it as (bsz, 1, C_enc).
            valid_ratios: per-image ratio of valid feature-map width, or
                None to disable masking.

        Returns:
            Tensor: (bsz, seq_len + 1, out_channels - 1) class scores.
            Dropout is applied when ``self.train_mode`` (set by
            ``forward``) is truthy.
        """
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))

        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = attn_weight.shape
        assert c == 1

        if valid_ratios is not None:
            # cal mask of attention weight
            for i, valid_ratio in enumerate(valid_ratios):
                # Keep at least one column unmasked. Otherwise a zero ratio
                # masks every position, and softmax over all -inf is NaN.
                valid_width = max(1, min(w, math.ceil(w * valid_ratio)))
                if valid_width < w:
                    attn_weight[i, :, :, valid_width:, :] = float('-inf')

        # Softmax over all h * w positions jointly.
        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        Decode all steps in one pass, feeding the label embeddings as input.

        img_metas: [label, valid_ratio]
        Returns (bsz, seq_len, out_channels - 1) scores; the output of the
        prepended holistic-feature step is dropped.
        '''
        valid_ratios = self._get_valid_ratios(feat, img_metas)

        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)
        # bsz * (seq_len + 1) * num_classes

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        '''
        Decode greedily for ``max_seq_len`` steps.

        The decoder input is initialised with start-token embeddings. After
        each step, the argmax prediction's embedding is written into the
        next input position. Returns the per-step softmax probabilities,
        shaped (bsz, max_seq_len, out_channels - 1).
        '''
        valid_ratios = self._get_valid_ratios(feat, img_metas)

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                # Feed the prediction back as the next step's input.
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs
 | 
						|
 | 
						|
 | 
						|
class SARHead(nn.Layer):
    """SAR recognition head: an SAREncoder followed by a ParallelSARDecoder.

    ``in_channels`` is the backbone channel count and is used as
    ``d_model`` for both sub-modules. ``enc_dim`` is the encoder RNN width
    (``d_enc``). The remaining arguments are forwarded unchanged to the
    encoder or decoder argument of the same name.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 enc_dim=512,
                 max_text_length=30,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 pred_concat=True,
                 **kwargs):
        super().__init__()

        # encoder module: feature map -> one holistic vector per image
        self.encoder = SAREncoder(
            d_model=in_channels,
            d_enc=enc_dim,
            enc_bi_rnn=enc_bi_rnn,
            enc_drop_rnn=enc_drop_rnn,
            enc_gru=enc_gru)

        # decoder module: 2D-attention decoding of the character sequence
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            d_model=in_channels,
            d_enc=enc_dim,
            d_k=d_k,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)

    def forward(self, feat, targets=None):
        '''
        Encode ``feat`` and decode it into per-step class scores.

        targets: ``targets[0]`` is the label batch (required in training)
        and ``targets[-1]`` the per-image valid width ratios. The whole
        list is passed to both sub-modules as ``img_metas``.

        In eval mode, returns softmax probabilities of shape
        (bsz, seq_len, num_classes). In training, returns the decoder's raw
        prediction scores.
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if not self.training:
            # (bsz, seq_len, num_classes)
            return self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)

        label = paddle.to_tensor(targets[0], dtype='int64')
        return self.decoder(feat, holistic_feat, label, img_metas=targets)
 |