Mirror of https://github.com/PaddlePaddle/PaddleOCR.git (synced 2025-11-03 03:09:16 +00:00)
use tensor.shape but not paddle.shape(tensor) (#11919)
* use tensor.shape but not paddle.shape(tensor)
* refine
* refine
This commit is contained in:
parent
d303d5f7b4
commit
2b3b3554c0
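For readers skimming the diff, here is a minimal, illustrative sketch (not taken from this commit) of what the change targets. In dynamic-graph mode, paddle.shape(x) returns the shape as a 1-D Tensor and adds an extra op, while x.shape gives a plain Python list of ints; the tensor name x and the shape values below are made up for the example.

import paddle

x = paddle.rand([2, 3, 32, 100])   # example tensor, shape chosen arbitrarily

shape_tensor = paddle.shape(x)     # 1-D int32 Tensor holding [2, 3, 32, 100]
shape_list = x.shape               # plain Python list: [2, 3, 32, 100]

# Both expose the batch size; indexing the list avoids the extra tensor op.
assert shape_list[0] == 2
assert int(shape_tensor[0]) == 2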
@@ -42,7 +42,7 @@ class MTB(nn.Layer):
         if self.cnn_num == 2:
             # (b, w, h, c)
             x = paddle.transpose(x, [0, 3, 2, 1])
-            x_shape = paddle.shape(x)
+            x_shape = x.shape
             x = paddle.reshape(
                 x, [x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
         return x
@@ -33,7 +33,7 @@ def drop_path(x, drop_prob=0., training=False):
     if drop_prob == 0. or not training:
         return x
     keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
-    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
+    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
     random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
     random_tensor = paddle.floor(random_tensor)  # binarize
     output = x.divide(keep_prob) * random_tensor
@@ -33,7 +33,7 @@ def drop_path(x, drop_prob=0., training=False):
     if drop_prob == 0. or not training:
         return x
     keep_prob = paddle.to_tensor(1 - drop_prob)
-    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
+    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
     random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
     random_tensor = paddle.floor(random_tensor)  # binarize
     output = x.divide(keep_prob) * random_tensor
@@ -243,7 +243,7 @@ class ViT(nn.Layer):

     def forward(self, x):
         x = self.patch_embed(x).flatten(2).transpose((0, 2, 1))
-        x = x + self.pos_embed[:, 1:, :]  #[:, :paddle.shape(x)[1], :]
+        x = x + self.pos_embed[:, 1:, :]  #[:, :x.shape[1], :]
         x = self.pos_drop(x)
         for blk in self.blocks1:
             x = blk(x)
@@ -43,7 +43,7 @@ def drop_path(x, drop_prob=0., training=False):
     if drop_prob == 0. or not training:
         return x
     keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
-    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
+    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
     random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
     random_tensor = paddle.floor(random_tensor)  # binarize
     output = x.divide(keep_prob) * random_tensor
@@ -113,7 +113,7 @@ class Attention(nn.Layer):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x):
-        # B= paddle.shape(x)[0]
+        # B= x.shape[0]
         N, C = x.shape[1:]
         qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
                                    self.num_heads)).transpose((2, 0, 3, 1, 4))
@@ -280,7 +280,7 @@ class VisionTransformer(nn.Layer):
             ones_(m.weight)

     def forward_features(self, x):
-        B = paddle.shape(x)[0]
+        B = x.shape[0]
         x = self.patch_embed(x)
         x = x + self.pos_embed
         x = self.pos_drop(x)
@@ -285,7 +285,7 @@ def _get_mask(length, max_length):
     Unmasked positions are filled with float(0.0).
     """
     length = length.unsqueeze(-1)
-    B = paddle.shape(length)[0]
+    B = length.shape[0]
     grid = paddle.arange(0, max_length).unsqueeze(0).tile([B, 1])
     zero_mask = paddle.zeros([B, max_length], dtype='float32')
     inf_mask = paddle.full([B, max_length], '-inf', dtype='float32')
@@ -81,7 +81,7 @@ class Embedding(nn.Layer):
             self.embed_dim)  # Embed encoder output to a word-embedding like

     def forward(self, x):
-        x = paddle.reshape(x, [paddle.shape(x)[0], -1])
+        x = paddle.reshape(x, [x.shape[0], -1])
         x = self.eEmbed(x)
         return x
@@ -105,7 +105,7 @@ class AttentionRecognitionHead(nn.Layer):

     def forward(self, x, embed):
         x, targets, lengths = x
-        batch_size = paddle.shape(x)[0]
+        batch_size = x.shape[0]
         # Decoder
         state = self.decoder.get_initial_state(embed)
         outputs = []
@@ -38,7 +38,7 @@ class AttentionHead(nn.Layer):
         return input_ont_hot

     def forward(self, inputs, targets=None, batch_max_length=25):
-        batch_size = paddle.shape(inputs)[0]
+        batch_size = inputs.shape[0]
         num_steps = batch_max_length

         hidden = paddle.zeros((batch_size, self.hidden_size))
@@ -294,7 +294,7 @@ class CPPDHead(nn.Layer):
         char_node_embed = self.char_node_embed(
             paddle.arange(self.out_channels)).unsqueeze(0)
         char_node_embed = paddle.tile(char_node_embed, [bs, 1, 1])
-        counting_char_num = paddle.shape(char_node_embed)[1]
+        counting_char_num = char_node_embed.shape[1]
         pos_node_embed = self.pos_node_embed(paddle.arange(
             self.max_len)).unsqueeze(0) + self.char_pos_embed
         pos_node_embed = paddle.tile(pos_node_embed, [bs, 1, 1])
@@ -50,7 +50,7 @@ class AddPos(nn.Layer):
         trunc_normal_(self.dec_pos_embed)

     def forward(self,x):
-        x = x + self.dec_pos_embed[:, :paddle.shape(x)[1], :]
+        x = x + self.dec_pos_embed[:, :x.shape[1], :]
         return x
@@ -150,7 +150,7 @@ class Transformer(nn.Layer):

     def forward_test(self, src):

-        bs = paddle.shape(src)[0]
+        bs = src.shape[0]
         if self.encoder is not None:
             src = self.positional_encoding(src)
             for encoder_layer in self.encoder:
@@ -164,7 +164,7 @@ class Transformer(nn.Layer):
             dec_seq_embed = self.embedding(dec_seq)
             dec_seq_embed = self.positional_encoding(dec_seq_embed)
             tgt_mask = self.generate_square_subsequent_mask(
-                paddle.shape(dec_seq_embed)[1])
+                dec_seq_embed.shape[1])
             tgt = dec_seq_embed
             for decoder_layer in self.decoder:
                 tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
@@ -175,7 +175,7 @@ class Transformer(nn.Layer):
             if paddle.equal_all(
                     preds_idx,
                     paddle.full(
-                        paddle.shape(preds_idx), 3, dtype='int64')):
+                        preds_idx.shape, 3, dtype='int64')):
                 break
             preds_prob = paddle.max(word_prob, axis=-1)
             dec_seq = paddle.concat(
@@ -198,7 +198,7 @@ class Transformer(nn.Layer):
                                 n_prev_active_inst, n_bm):
            """ Collect tensor parts associated to active instances. """

-            beamed_tensor_shape = paddle.shape(beamed_tensor)
+            beamed_tensor_shape = beamed_tensor.shape
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
                         beamed_tensor_shape[2])
@@ -243,7 +243,7 @@ class Transformer(nn.Layer):
             dec_seq = self.embedding(dec_seq)
             dec_seq = self.positional_encoding(dec_seq)
             tgt_mask = self.generate_square_subsequent_mask(
-                paddle.shape(dec_seq)[1])
+                dec_seq.shape[1])
             tgt = dec_seq
             for decoder_layer in self.decoder:
                 tgt = decoder_layer(tgt, enc_output, self_mask=tgt_mask)
@@ -294,7 +294,7 @@ class Transformer(nn.Layer):
         src_enc = images

         n_bm = self.beam_size
-        src_shape = paddle.shape(src_enc)
+        src_shape = src_enc.shape
         inst_dec_beams = [Beam(n_bm) for _ in range(1)]
         active_inst_idx_list = list(range(1))
         # Repeat data for beam search
@@ -500,7 +500,7 @@ class PositionalEncoding(nn.Layer):
             >>> output = pos_encoder(x)
         """
         x = x.transpose([1, 0, 2])
-        x = x + self.pe[:paddle.shape(x)[0], :]
+        x = x + self.pe[:x.shape[0], :]
         return self.dropout(x).transpose([1, 0, 2])
@@ -552,13 +552,13 @@ class PositionalEncoding_2d(nn.Layer):
         Examples:
             >>> output = pos_encoder(x)
         """
-        w_pe = self.pe[:paddle.shape(x)[-1], :]
+        w_pe = self.pe[:x.shape[-1], :]
         w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
         w_pe = w_pe * w1
         w_pe = paddle.transpose(w_pe, [1, 2, 0])
         w_pe = paddle.unsqueeze(w_pe, 2)

-        h_pe = self.pe[:paddle.shape(x).shape[-2], :]
+        h_pe = self.pe[:x.shape.shape[-2], :]
         w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
         h_pe = h_pe * w2
         h_pe = paddle.transpose(h_pe, [1, 2, 0])
@@ -83,7 +83,7 @@ class SAREncoder(nn.Layer):

     def forward(self, feat, img_metas=None):
         if img_metas is not None:
-            assert len(img_metas[0]) == paddle.shape(feat)[0]
+            assert len(img_metas[0]) == feat.shape[0]

         valid_ratios = None
         if img_metas is not None and self.mask:
@@ -99,7 +99,7 @@ class SAREncoder(nn.Layer):
         if valid_ratios is not None:
             valid_hf = []
             T = paddle.shape(holistic_feat)[1]
-            for i in range(paddle.shape(valid_ratios)[0]):
+            for i in range(valid_ratios.shape[0]):
                 valid_step = paddle.minimum(
                     T, paddle.ceil(valid_ratios[i] * T).astype(T.dtype)) - 1
                 valid_hf.append(holistic_feat[i, valid_step, :])
@@ -253,7 +253,7 @@ class ParallelSARDecoder(BaseDecoder):

         if valid_ratios is not None:
             # cal mask of attention weight
-            for i in range(paddle.shape(valid_ratios)[0]):
+            for i in range(valid_ratios.shape[0]):
                 valid_width = paddle.minimum(
                     w, paddle.ceil(valid_ratios[i] * w).astype("int32"))
                 if valid_width < w:
@@ -292,7 +292,7 @@ class ParallelSARDecoder(BaseDecoder):
             img_metas: [label, valid_ratio]
         '''
         if img_metas is not None:
-            assert paddle.shape(img_metas[0])[0] == paddle.shape(feat)[0]
+            assert img_metas[0].shape[0] == feat.shape[0]

         valid_ratios = None
         if img_metas is not None and self.mask:
@@ -283,7 +283,7 @@ class SATRNEncoder(nn.Layer):
             Tensor: A tensor of shape :math:`(N, T, D_m)`.
         """
         if valid_ratios is None:
-            bs = paddle.shape(feat)[0]
+            bs = feat.shape[0]
             valid_ratios = paddle.full((bs, 1), 1., dtype=paddle.float32)

         feat = self.position_enc(feat)
@@ -42,7 +42,7 @@ class SPINAttentionHead(nn.Layer):
         return input_ont_hot

     def forward(self, inputs, targets=None, batch_max_length=25):
-        batch_size = paddle.shape(inputs)[0]
+        batch_size = inputs.shape[0]
         num_steps = batch_max_length + 1  # +1 for [sos] at end of sentence

         hidden = (paddle.zeros((batch_size, self.hidden_size)),
@@ -78,7 +78,7 @@ class MultiHeadedAttention(nn.Layer):
     def forward(self, query, key, value, mask=None, attention_map=None):
         if mask is not None:
             mask = mask.unsqueeze(1)
-        nbatches = paddle.shape(query)[0]
+        nbatches = query.shape[0]

         query, key, value = \
             [paddle.transpose(l(x).reshape([nbatches, -1, self.h, self.d_k]), [0,2,1,3])
@@ -230,7 +230,7 @@ class PositionalEncoding(nn.Layer):
         self.register_buffer('pe', pe)

     def forward(self, x):
-        x = x + self.pe[:, :paddle.shape(x)[1]]
+        x = x + self.pe[:, :x.shape[1]]
         return self.dropout(x)
@@ -71,7 +71,7 @@ class TableMasterHead(nn.Layer):
        """
        trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3)

-       tgt_len = paddle.shape(tgt)[1]
+       tgt_len = tgt.shape[1]
        trg_sub_mask = paddle.tril(
            paddle.ones(
                ([tgt_len, tgt_len]), dtype=paddle.float32))
@@ -279,5 +279,5 @@ class PositionalEncoding(nn.Layer):
         self.register_buffer('pe', pe)

     def forward(self, feat, **kwargs):
-        feat = feat + self.pe[:, :paddle.shape(feat)[1]]  # pe 1*5000*512
+        feat = feat + self.pe[:, :feat.shape[1]]  # pe 1*5000*512
         return self.dropout(feat)
@@ -305,7 +305,7 @@ class CSPPAN(nn.Layer):
             feat_heigh = inner_outs[0]
             feat_low = inputs[idx - 1]
             upsample_feat = F.upsample(
-                feat_heigh, size=paddle.shape(feat_low)[2:4], mode="nearest")
+                feat_heigh, size=feat_low.shape[2:4], mode="nearest")

             inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
                 paddle.concat([upsample_feat, feat_low], 1))
@@ -222,7 +222,7 @@ class Cross_Attention(nn.Layer):
         return f_weight

     def forward(self, f_common):
-        f_shape = paddle.shape(f_common)
+        f_shape = f_common.shape
         # print('f_shape: ', f_shape)

         f_theta = self.theta_conv(f_common)
@@ -80,7 +80,7 @@ class FeatureEnhancer(nn.Layer):
         global_info: (batch, embedding_size, 1, 1)
         conv_feature: (batch, channel, H, W)
         '''
-        batch = paddle.shape(conv_feature)[0]
+        batch = conv_feature.shape[0]
         position2d = positionalencoding2d(
             64, 16, 64).cast('float32').unsqueeze(0).reshape([1, 64, 1024])
         position2d = position2d.tile([batch, 1, 1])
@@ -276,7 +276,7 @@ class RecurrentResidualBlock(nn.Layer):
         residual = self.conv2(residual)
         residual = self.bn2(residual)

-        size = paddle.shape(residual)
+        size = residual.shape
         residual = residual.reshape([size[0], size[1], -1])
         residual = self.feature_enhancer(residual)
         residual = residual.reshape([size[0], size[1], size[2], size[3]])
@@ -152,7 +152,7 @@ class TPSSpatialTransformer(nn.Layer):
         assert source_control_points.ndimension() == 3
         assert source_control_points.shape[1] == self.num_control_points
         assert source_control_points.shape[2] == 2
-        batch_size = paddle.shape(source_control_points)[0]
+        batch_size = source_control_points.shape[0]

         padding_matrix = paddle.expand(
             self.padding_matrix, shape=[batch_size, 3, 2])