Mirror of https://github.com/PaddlePaddle/PaddleOCR.git, synced 2025-12-29 07:58:41 +00:00
fix code style
This commit is contained in:
parent
ae09ef607f
commit
d611515803
@@ -19,6 +19,7 @@ class SAREncoder(nn.Layer):
         d_enc (int): Dim of encoder RNN layer.
         mask (bool): If True, mask padding in RNN sequence.
     """
+
     def __init__(self,
                  enc_bi_rnn=False,
                  enc_drop_rnn=0.1,
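For orientation, a minimal sketch (names and defaults assumed from the surrounding file, not part of this diff) of how `enc_bi_rnn`, `enc_drop_rnn` and `enc_gru` typically feed the encoder RNN kwargs that the next hunk reformats:

```python
import paddle.nn as nn

# Illustrative values; d_model/d_enc mirror the SAREncoder defaults.
enc_bi_rnn, enc_drop_rnn, enc_gru = False, 0.1, False
d_model, d_enc = 512, 512

kwargs = dict(
    input_size=d_model,
    hidden_size=d_enc,
    num_layers=2,
    time_major=False,
    dropout=enc_drop_rnn,
    direction='bidirectional' if enc_bi_rnn else 'forward')
rnn_encoder = nn.GRU(**kwargs) if enc_gru else nn.LSTM(**kwargs)
```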
@@ -51,33 +52,31 @@ class SAREncoder(nn.Layer):
             num_layers=2,
             time_major=False,
             dropout=enc_drop_rnn,
-            direction=direction
-        )
+            direction=direction)
         if enc_gru:
             self.rnn_encoder = nn.GRU(**kwargs)
         else:
             self.rnn_encoder = nn.LSTM(**kwargs)
 
         # global feature transformation
         encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
         self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
 
     def forward(self, feat, img_metas=None):
         if img_metas is not None:
             assert len(img_metas[0]) == feat.shape[0]
 
         valid_ratios = None
         if img_metas is not None and self.mask:
             valid_ratios = img_metas[-1]
 
-        h_feat = feat.shape[2] # bsz c h w
+        h_feat = feat.shape[2]  # bsz c h w
         feat_v = F.max_pool2d(
-            feat, kernel_size=(h_feat, 1), stride=1, padding=0
-        )
-        feat_v = feat_v.squeeze(2) # bsz * C * W
-        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C
-        holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
+            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
+        feat_v = feat_v.squeeze(2)  # bsz * C * W
+        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
+        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C
 
         if valid_ratios is not None:
             valid_hf = []
             T = holistic_feat.shape[1]
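The forward path in this hunk collapses the feature map's height with a max-pool and then runs the RNN along the width. A shape sketch with illustrative sizes:

```python
import paddle
import paddle.nn.functional as F

feat = paddle.randn([2, 512, 6, 40])  # bsz * c * h * w (illustrative)
h_feat = feat.shape[2]
feat_v = F.max_pool2d(feat, kernel_size=(h_feat, 1), stride=1, padding=0)
feat_v = feat_v.squeeze(2)                         # bsz * c * w
feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * w * c
print(feat_v.shape)  # [2, 40, 512]: one c-dim vector per horizontal step
```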
@@ -86,11 +85,11 @@ class SAREncoder(nn.Layer):
                 valid_hf.append(holistic_feat[i, valid_step, :])
             valid_hf = paddle.stack(valid_hf, axis=0)
         else:
-            valid_hf = holistic_feat[:, -1, :] # bsz * C
-        holistic_feat = self.linear(valid_hf) # bsz * C
+            valid_hf = holistic_feat[:, -1, :]  # bsz * C
+        holistic_feat = self.linear(valid_hf)  # bsz * C
 
         return holistic_feat
 
 
 class BaseDecoder(nn.Layer):
     def __init__(self, **kwargs):
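Context for the `valid_hf` loop in this hunk: for right-padded images the encoder reads its holistic feature at the last *valid* RNN step rather than the last column. A sketch, assuming `valid_ratio` is valid width over padded width (the ratio values here are hypothetical):

```python
import math
import paddle

holistic_feat = paddle.randn([2, 40, 512])  # bsz * T * C (illustrative)
valid_ratios = [0.5, 1.0]                   # hypothetical per-image ratios
T = holistic_feat.shape[1]

valid_hf = []
for i, valid_ratio in enumerate(valid_ratios):
    valid_step = min(T, math.ceil(T * valid_ratio)) - 1  # last valid index
    valid_hf.append(holistic_feat[i, valid_step, :])
valid_hf = paddle.stack(valid_hf, axis=0)   # bsz * C
```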
@@ -102,7 +101,7 @@ class BaseDecoder(nn.Layer):
     def forward_test(self, feat, out_enc, img_metas):
         raise NotImplementedError
 
-    def forward(self,
+    def forward(self,
                 feat,
                 out_enc,
                 label=None,
@@ -135,20 +134,21 @@ class ParallelSARDecoder(BaseDecoder):
         attention with holistic feature and hidden state.
     """
 
-    def __init__(self,
-                 out_channels, # 90 + unknown + start + padding
-                 enc_bi_rnn=False,
-                 dec_bi_rnn=False,
-                 dec_drop_rnn=0.0,
-                 dec_gru=False,
-                 d_model=512,
-                 d_enc=512,
-                 d_k=64,
-                 pred_dropout=0.1,
-                 max_text_length=30,
-                 mask=True,
-                 pred_concat=True,
-                 **kwargs):
+    def __init__(
+            self,
+            out_channels,  # 90 + unknown + start + padding
+            enc_bi_rnn=False,
+            dec_bi_rnn=False,
+            dec_drop_rnn=0.0,
+            dec_gru=False,
+            d_model=512,
+            d_enc=512,
+            d_k=64,
+            pred_dropout=0.1,
+            max_text_length=30,
+            mask=True,
+            pred_concat=True,
+            **kwargs):
         super().__init__()
 
         self.num_classes = out_channels
@@ -165,7 +165,8 @@ class ParallelSARDecoder(BaseDecoder):
 
         # 2D attention layer
         self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
-        self.conv3x3_1 = nn.Conv2D(d_model, d_k, kernel_size=3, stride=1, padding=1)
+        self.conv3x3_1 = nn.Conv2D(
+            d_model, d_k, kernel_size=3, stride=1, padding=1)
         self.conv1x1_2 = nn.Linear(d_k, 1)
 
         # Decoder RNN layer
@@ -180,8 +181,7 @@ class ParallelSARDecoder(BaseDecoder):
             num_layers=2,
             time_major=False,
             dropout=dec_drop_rnn,
-            direction=direction
-        )
+            direction=direction)
         if dec_gru:
             self.rnn_decoder = nn.GRU(**kwargs)
         else:
@@ -189,8 +189,10 @@ class ParallelSARDecoder(BaseDecoder):
 
         # Decoder input embedding
         self.embedding = nn.Embedding(
-            self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx)
+            self.num_classes,
+            encoder_rnn_out_size,
+            padding_idx=self.padding_idx)
 
         # Prediction layer
         self.pred_dropout = nn.Dropout(pred_dropout)
         pred_num_classes = num_classes - 1
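The `padding_idx` argument that this hunk splits across lines keeps the padding token's embedding row at zeros (and out of gradient updates), so padded label positions contribute nothing downstream. A small sketch with assumed vocabulary sizes and ids:

```python
import paddle
import paddle.nn as nn

emb = nn.Embedding(94, 512, padding_idx=93)  # 93 as padding id is assumed
tokens = paddle.to_tensor([5, 93, 91], dtype='int64')
out = emb(tokens)
print(float(out[1].abs().sum()))  # 0.0: the padding row stays all-zero
```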
@@ -205,11 +207,11 @@ class ParallelSARDecoder(BaseDecoder):
                       feat,
                       holistic_feat,
                       valid_ratios=None):
 
         y = self.rnn_decoder(decoder_input)[0]
         # y: bsz * (seq_len + 1) * hidden_size
 
-        attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
+        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
         bsz, seq_len, attn_size = attn_query.shape
         attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
         # (bsz, seq_len + 1, attn_size, 1, 1)
@@ -220,7 +222,7 @@ class ParallelSARDecoder(BaseDecoder):
         # bsz * 1 * attn_size * h * w
 
-        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
+        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
 
         # bsz * (seq_len + 1) * attn_size * h * w
         attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
         # bsz * (seq_len + 1) * h * w * attn_size
@@ -237,25 +239,28 @@ class ParallelSARDecoder(BaseDecoder):
 
         attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
         attn_weight = F.softmax(attn_weight, axis=-1)
 
         attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
         attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
         # attn_weight: bsz * T * c * h * w
         # feat: bsz * c * h * w
-        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
+        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
+                               (3, 4),
+                               keepdim=False)
         # bsz * (seq_len + 1) * C
 
         # Linear transformation
         if self.pred_concat:
             hf_c = holistic_feat.shape[-1]
-            holistic_feat = paddle.expand(holistic_feat, shape=[bsz, seq_len, hf_c])
+            holistic_feat = paddle.expand(
+                holistic_feat, shape=[bsz, seq_len, hf_c])
             y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
         else:
             y = self.prediction(attn_feat)
         # bsz * (seq_len + 1) * num_classes
         if self.train_mode:
             y = self.pred_dropout(y)
 
         return y
 
     def forward_train(self, feat, out_enc, label, img_metas):
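What the `attn_feat` reduction reformatted above computes: attention logits are softmaxed over the flattened h*w grid, then used as weights for a sum over feature-map locations. A shape sketch (sizes illustrative; the single attention channel broadcasts against c):

```python
import paddle
import paddle.nn.functional as F

bsz, T, c, h, w = 2, 31, 512, 6, 40
attn_weight = paddle.randn([bsz, T, h, w, 1])  # conv1x1_2 output: 1 channel

attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
attn_weight = F.softmax(attn_weight, axis=-1)  # normalize over h*w
attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, 1])
attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])

feat = paddle.randn([bsz, c, h, w])
attn_feat = paddle.sum(
    paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
print(attn_feat.shape)  # [2, 31, 512]: one attended vector per step
```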
@@ -268,7 +273,7 @@ class ParallelSARDecoder(BaseDecoder):
         valid_ratios = None
         if img_metas is not None and self.mask:
             valid_ratios = img_metas[-1]
-
+
         label = label.cuda()
         lab_embedding = self.embedding(label)
         # bsz * seq_len * emb_dim
@@ -277,11 +282,10 @@ class ParallelSARDecoder(BaseDecoder):
         in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
         # bsz * (seq_len + 1) * C
         out_dec = self._2d_attention(
-            in_dec, feat, out_enc, valid_ratios=valid_ratios
-        )
+            in_dec, feat, out_enc, valid_ratios=valid_ratios)
         # bsz * (seq_len + 1) * num_classes
 
-        return out_dec[:, 1:, :] # bsz * seq_len * num_classes
+        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes
 
     def forward_test(self, feat, out_enc, img_metas):
         if img_metas is not None:
@@ -289,13 +293,12 @@ class ParallelSARDecoder(BaseDecoder):
 
         valid_ratios = None
         if img_metas is not None and self.mask:
-            valid_ratios = img_metas[-1]
+            valid_ratios = img_metas[-1]
 
         seq_len = self.max_seq_len
         bsz = feat.shape[0]
-        start_token = paddle.full((bsz, ),
-                                  fill_value=self.start_idx,
-                                  dtype='int64')
+        start_token = paddle.full(
+            (bsz, ), fill_value=self.start_idx, dtype='int64')
         # bsz
         start_token = self.embedding(start_token)
         # bsz * emb_dim
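The `start_token` lines reformatted above seed greedy decoding: one `<start>` id per image, embedded and placed ahead of the per-step decoder inputs. A sketch with assumed ids and sizes:

```python
import paddle
import paddle.nn as nn

bsz, emb_dim, start_idx = 2, 512, 91   # start_idx value is assumed
embedding = nn.Embedding(94, emb_dim)

start_token = paddle.full((bsz, ), fill_value=start_idx, dtype='int64')
start_token = embedding(start_token)   # bsz * emb_dim
print(start_token.shape)               # [2, 512]
```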
@@ -311,68 +314,70 @@ class ParallelSARDecoder(BaseDecoder):
         outputs = []
         for i in range(1, seq_len + 1):
             decoder_output = self._2d_attention(
-                decoder_input, feat, out_enc, valid_ratios=valid_ratios
-            )
-            char_output = decoder_output[:, i, :] # bsz * num_classes
+                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
+            char_output = decoder_output[:, i, :]  # bsz * num_classes
             char_output = F.softmax(char_output, -1)
             outputs.append(char_output)
             max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
-            char_embedding = self.embedding(max_idx) # bsz * emb_dim
+            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
             if i < seq_len:
                 decoder_input[:, i + 1, :] = char_embedding
 
-        outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes
+        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes
 
         return outputs
 
 
 class SARHead(nn.Layer):
-    def __init__(self,
-                 out_channels,
-                 enc_bi_rnn=False,
-                 enc_drop_rnn=0.1,
-                 enc_gru=False,
-                 dec_bi_rnn=False,
-                 dec_drop_rnn=0.0,
-                 dec_gru=False,
-                 d_k=512,
-                 pred_dropout=0.1,
-                 max_text_length=30,
-                 pred_concat=True,
-                 **kwargs):
+    def __init__(self,
+                 out_channels,
+                 enc_bi_rnn=False,
+                 enc_drop_rnn=0.1,
+                 enc_gru=False,
+                 dec_bi_rnn=False,
+                 dec_drop_rnn=0.0,
+                 dec_gru=False,
+                 d_k=512,
+                 pred_dropout=0.1,
+                 max_text_length=30,
+                 pred_concat=True,
+                 **kwargs):
         super(SARHead, self).__init__()
 
         # encoder module
         self.encoder = SAREncoder(
-            enc_bi_rnn=enc_bi_rnn,
-            enc_drop_rnn=enc_drop_rnn,
-            enc_gru=enc_gru)
+            enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)
 
         # decoder module
         self.decoder = ParallelSARDecoder(
             out_channels=out_channels,
-            enc_bi_rnn=enc_bi_rnn,
+            enc_bi_rnn=enc_bi_rnn,
             dec_bi_rnn=dec_bi_rnn,
             dec_drop_rnn=dec_drop_rnn,
             dec_gru=dec_gru,
             d_k=d_k,
             pred_dropout=pred_dropout,
             max_text_length=max_text_length,
-            pred_concat=pred_concat)
+            pred_concat=pred_concat)
 
     def forward(self, feat, targets=None):
         '''
         img_metas: [label, valid_ratio]
         '''
-        holistic_feat = self.encoder(feat, targets) # bsz c
+        holistic_feat = self.encoder(feat, targets)  # bsz c
 
         if self.training:
-            label = targets[0] # label
+            label = targets[0]  # label
             label = paddle.to_tensor(label, dtype='int64')
-            final_out = self.decoder(feat, holistic_feat, label, img_metas=targets)
+            final_out = self.decoder(
+                feat, holistic_feat, label, img_metas=targets)
         if not self.training:
-            final_out = self.decoder(feat, holistic_feat, label=None, img_metas=targets, train_mode=False)
+            final_out = self.decoder(
+                feat,
+                holistic_feat,
+                label=None,
+                img_metas=targets,
+                train_mode=False)
         # (bsz, seq_len, num_classes)
 
         return final_out
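End to end, a hedged usage sketch of the reformatted `SARHead` at inference time, assuming the module constructs cleanly at this revision and a 512-channel backbone feature map; with `targets=None` both encoder and decoder skip the valid-ratio masking, and `forward_test` returns per-step softmax scores:

```python
import paddle

head = SARHead(out_channels=94)  # out_channels value is illustrative
head.eval()

feat = paddle.randn([2, 512, 6, 40])  # bsz * c * h * w
with paddle.no_grad():
    probs = head(feat, targets=None)
print(probs.shape)  # [bsz, seq_len, out_channels], softmaxed per step
```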