use tensor.shape instead of paddle.shape(tensor) (#11919)

* use tensor.shape instead of paddle.shape(tensor)

* refine

* refine
wanghuancoder 2024-04-17 10:54:59 +08:00 committed by GitHub
parent d303d5f7b4
commit 2b3b3554c0
19 changed files with 37 additions and 37 deletions
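A minimal sketch (with an illustrative tensor x, not taken from the diff) of the difference behind the change: in dynamic-graph mode, tensor.shape returns a plain Python list, while paddle.shape(tensor) returns an int Tensor holding the shape values. Both index the same way, so the substitution is mechanical.

    import paddle

    x = paddle.rand([4, 3, 32, 100])      # illustrative 4-D input

    shape_tensor = paddle.shape(x)        # int Tensor with values [4, 3, 32, 100]
    shape_list = x.shape                  # plain Python list: [4, 3, 32, 100]

    batch_from_tensor = shape_tensor[0]   # an int Tensor holding 4
    batch_from_list = x.shape[0]          # a plain Python int: 4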

@@ -42,7 +42,7 @@ class MTB(nn.Layer):
if self.cnn_num == 2:
# (b, w, h, c)
x = paddle.transpose(x, [0, 3, 2, 1])
x_shape = paddle.shape(x)
x_shape = x.shape
x = paddle.reshape(
x, [x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
return x

@@ -33,7 +33,7 @@ def drop_path(x, drop_prob=0., training=False):
if drop_prob == 0. or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor

@@ -33,7 +33,7 @@ def drop_path(x, drop_prob=0., training=False):
if drop_prob == 0. or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
@@ -243,7 +243,7 @@ class ViT(nn.Layer):
def forward(self, x):
x = self.patch_embed(x).flatten(2).transpose((0, 2, 1))
x = x + self.pos_embed[:, 1:, :] #[:, :paddle.shape(x)[1], :]
x = x + self.pos_embed[:, 1:, :] #[:, :x.shape[1], :]
x = self.pos_drop(x)
for blk in self.blocks1:
x = blk(x)

@@ -43,7 +43,7 @@ def drop_path(x, drop_prob=0., training=False):
if drop_prob == 0. or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
@@ -113,7 +113,7 @@ class Attention(nn.Layer):
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
# B= paddle.shape(x)[0]
# B= x.shape[0]
N, C = x.shape[1:]
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
self.num_heads)).transpose((2, 0, 3, 1, 4))
@@ -280,7 +280,7 @@ class VisionTransformer(nn.Layer):
ones_(m.weight)
def forward_features(self, x):
B = paddle.shape(x)[0]
B = x.shape[0]
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)

@@ -285,7 +285,7 @@ def _get_mask(length, max_length):
Unmasked positions are filled with float(0.0).
"""
length = length.unsqueeze(-1)
B = paddle.shape(length)[0]
B = length.shape[0]
grid = paddle.arange(0, max_length).unsqueeze(0).tile([B, 1])
zero_mask = paddle.zeros([B, max_length], dtype='float32')
inf_mask = paddle.full([B, max_length], '-inf', dtype='float32')

@@ -81,7 +81,7 @@ class Embedding(nn.Layer):
self.embed_dim) # Embed encoder output to a word-embedding like
def forward(self, x):
x = paddle.reshape(x, [paddle.shape(x)[0], -1])
x = paddle.reshape(x, [x.shape[0], -1])
x = self.eEmbed(x)
return x
@@ -105,7 +105,7 @@ class AttentionRecognitionHead(nn.Layer):
def forward(self, x, embed):
x, targets, lengths = x
batch_size = paddle.shape(x)[0]
batch_size = x.shape[0]
# Decoder
state = self.decoder.get_initial_state(embed)
outputs = []

@@ -38,7 +38,7 @@ class AttentionHead(nn.Layer):
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = paddle.shape(inputs)[0]
batch_size = inputs.shape[0]
num_steps = batch_max_length
hidden = paddle.zeros((batch_size, self.hidden_size))

@@ -294,7 +294,7 @@ class CPPDHead(nn.Layer):
char_node_embed = self.char_node_embed(
paddle.arange(self.out_channels)).unsqueeze(0)
char_node_embed = paddle.tile(char_node_embed, [bs, 1, 1])
counting_char_num = paddle.shape(char_node_embed)[1]
counting_char_num = char_node_embed.shape[1]
pos_node_embed = self.pos_node_embed(paddle.arange(
self.max_len)).unsqueeze(0) + self.char_pos_embed
pos_node_embed = paddle.tile(pos_node_embed, [bs, 1, 1])

@@ -50,7 +50,7 @@ class AddPos(nn.Layer):
trunc_normal_(self.dec_pos_embed)
def forward(self,x):
x = x + self.dec_pos_embed[:, :paddle.shape(x)[1], :]
x = x + self.dec_pos_embed[:, :x.shape[1], :]
return x

@@ -150,7 +150,7 @@ class Transformer(nn.Layer):
def forward_test(self, src):
bs = paddle.shape(src)[0]
bs = src.shape[0]
if self.encoder is not None:
src = self.positional_encoding(src)
for encoder_layer in self.encoder:
@@ -164,7 +164,7 @@ class Transformer(nn.Layer):
dec_seq_embed = self.embedding(dec_seq)
dec_seq_embed = self.positional_encoding(dec_seq_embed)
tgt_mask = self.generate_square_subsequent_mask(
paddle.shape(dec_seq_embed)[1])
dec_seq_embed.shape[1])
tgt = dec_seq_embed
for decoder_layer in self.decoder:
tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
@@ -175,7 +175,7 @@ class Transformer(nn.Layer):
if paddle.equal_all(
preds_idx,
paddle.full(
paddle.shape(preds_idx), 3, dtype='int64')):
preds_idx.shape, 3, dtype='int64')):
break
preds_prob = paddle.max(word_prob, axis=-1)
dec_seq = paddle.concat(
@@ -198,7 +198,7 @@ class Transformer(nn.Layer):
n_prev_active_inst, n_bm):
""" Collect tensor parts associated to active instances. """
beamed_tensor_shape = paddle.shape(beamed_tensor)
beamed_tensor_shape = beamed_tensor.shape
n_curr_active_inst = len(curr_active_inst_idx)
new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
beamed_tensor_shape[2])
@@ -243,7 +243,7 @@ class Transformer(nn.Layer):
dec_seq = self.embedding(dec_seq)
dec_seq = self.positional_encoding(dec_seq)
tgt_mask = self.generate_square_subsequent_mask(
paddle.shape(dec_seq)[1])
dec_seq.shape[1])
tgt = dec_seq
for decoder_layer in self.decoder:
tgt = decoder_layer(tgt, enc_output, self_mask=tgt_mask)
@@ -294,7 +294,7 @@ class Transformer(nn.Layer):
src_enc = images
n_bm = self.beam_size
src_shape = paddle.shape(src_enc)
src_shape = src_enc.shape
inst_dec_beams = [Beam(n_bm) for _ in range(1)]
active_inst_idx_list = list(range(1))
# Repeat data for beam search
@@ -500,7 +500,7 @@ class PositionalEncoding(nn.Layer):
>>> output = pos_encoder(x)
"""
x = x.transpose([1, 0, 2])
x = x + self.pe[:paddle.shape(x)[0], :]
x = x + self.pe[:x.shape[0], :]
return self.dropout(x).transpose([1, 0, 2])
@@ -552,13 +552,13 @@ class PositionalEncoding_2d(nn.Layer):
Examples:
>>> output = pos_encoder(x)
"""
w_pe = self.pe[:paddle.shape(x)[-1], :]
w_pe = self.pe[:x.shape[-1], :]
w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
w_pe = w_pe * w1
w_pe = paddle.transpose(w_pe, [1, 2, 0])
w_pe = paddle.unsqueeze(w_pe, 2)
h_pe = self.pe[:paddle.shape(x).shape[-2], :]
h_pe = self.pe[:x.shape[-2], :]
w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
h_pe = h_pe * w2
h_pe = paddle.transpose(h_pe, [1, 2, 0])

@@ -83,7 +83,7 @@ class SAREncoder(nn.Layer):
def forward(self, feat, img_metas=None):
if img_metas is not None:
assert len(img_metas[0]) == paddle.shape(feat)[0]
assert len(img_metas[0]) == feat.shape[0]
valid_ratios = None
if img_metas is not None and self.mask:
@@ -99,7 +99,7 @@ class SAREncoder(nn.Layer):
if valid_ratios is not None:
valid_hf = []
T = paddle.shape(holistic_feat)[1]
for i in range(paddle.shape(valid_ratios)[0]):
for i in range(valid_ratios.shape[0]):
valid_step = paddle.minimum(
T, paddle.ceil(valid_ratios[i] * T).astype(T.dtype)) - 1
valid_hf.append(holistic_feat[i, valid_step, :])
@@ -253,7 +253,7 @@ class ParallelSARDecoder(BaseDecoder):
if valid_ratios is not None:
# cal mask of attention weight
for i in range(paddle.shape(valid_ratios)[0]):
for i in range(valid_ratios.shape[0]):
valid_width = paddle.minimum(
w, paddle.ceil(valid_ratios[i] * w).astype("int32"))
if valid_width < w:
@@ -292,7 +292,7 @@ class ParallelSARDecoder(BaseDecoder):
img_metas: [label, valid_ratio]
'''
if img_metas is not None:
assert paddle.shape(img_metas[0])[0] == paddle.shape(feat)[0]
assert img_metas[0].shape[0] == feat.shape[0]
valid_ratios = None
if img_metas is not None and self.mask:

@@ -283,7 +283,7 @@ class SATRNEncoder(nn.Layer):
Tensor: A tensor of shape :math:`(N, T, D_m)`.
"""
if valid_ratios is None:
bs = paddle.shape(feat)[0]
bs = feat.shape[0]
valid_ratios = paddle.full((bs, 1), 1., dtype=paddle.float32)
feat = self.position_enc(feat)

@@ -42,7 +42,7 @@ class SPINAttentionHead(nn.Layer):
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = paddle.shape(inputs)[0]
batch_size = inputs.shape[0]
num_steps = batch_max_length + 1 # +1 for [sos] at end of sentence
hidden = (paddle.zeros((batch_size, self.hidden_size)),

@@ -78,7 +78,7 @@ class MultiHeadedAttention(nn.Layer):
def forward(self, query, key, value, mask=None, attention_map=None):
if mask is not None:
mask = mask.unsqueeze(1)
nbatches = paddle.shape(query)[0]
nbatches = query.shape[0]
query, key, value = \
[paddle.transpose(l(x).reshape([nbatches, -1, self.h, self.d_k]), [0,2,1,3])
@@ -230,7 +230,7 @@ class PositionalEncoding(nn.Layer):
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :paddle.shape(x)[1]]
x = x + self.pe[:, :x.shape[1]]
return self.dropout(x)

@@ -71,7 +71,7 @@ class TableMasterHead(nn.Layer):
"""
trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3)
tgt_len = paddle.shape(tgt)[1]
tgt_len = tgt.shape[1]
trg_sub_mask = paddle.tril(
paddle.ones(
([tgt_len, tgt_len]), dtype=paddle.float32))
@@ -279,5 +279,5 @@ class PositionalEncoding(nn.Layer):
self.register_buffer('pe', pe)
def forward(self, feat, **kwargs):
feat = feat + self.pe[:, :paddle.shape(feat)[1]] # pe 1*5000*512
feat = feat + self.pe[:, :feat.shape[1]] # pe 1*5000*512
return self.dropout(feat)

@@ -305,7 +305,7 @@ class CSPPAN(nn.Layer):
feat_heigh = inner_outs[0]
feat_low = inputs[idx - 1]
upsample_feat = F.upsample(
feat_heigh, size=paddle.shape(feat_low)[2:4], mode="nearest")
feat_heigh, size=feat_low.shape[2:4], mode="nearest")
inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
paddle.concat([upsample_feat, feat_low], 1))

@@ -222,7 +222,7 @@ class Cross_Attention(nn.Layer):
return f_weight
def forward(self, f_common):
f_shape = paddle.shape(f_common)
f_shape = f_common.shape
# print('f_shape: ', f_shape)
f_theta = self.theta_conv(f_common)

@@ -80,7 +80,7 @@ class FeatureEnhancer(nn.Layer):
global_info: (batch, embedding_size, 1, 1)
conv_feature: (batch, channel, H, W)
'''
batch = paddle.shape(conv_feature)[0]
batch = conv_feature.shape[0]
position2d = positionalencoding2d(
64, 16, 64).cast('float32').unsqueeze(0).reshape([1, 64, 1024])
position2d = position2d.tile([batch, 1, 1])
@@ -276,7 +276,7 @@ class RecurrentResidualBlock(nn.Layer):
residual = self.conv2(residual)
residual = self.bn2(residual)
size = paddle.shape(residual)
size = residual.shape
residual = residual.reshape([size[0], size[1], -1])
residual = self.feature_enhancer(residual)
residual = residual.reshape([size[0], size[1], size[2], size[3]])

@@ -152,7 +152,7 @@ class TPSSpatialTransformer(nn.Layer):
assert source_control_points.ndimension() == 3
assert source_control_points.shape[1] == self.num_control_points
assert source_control_points.shape[2] == 2
batch_size = paddle.shape(source_control_points)[0]
batch_size = source_control_points.shape[0]
padding_matrix = paddle.expand(
self.padding_matrix, shape=[batch_size, 3, 2])