Mirror of https://github.com/PaddlePaddle/PaddleOCR.git (synced 2025-10-30 17:29:13 +00:00), commit 75d16610f4.
File: 343 lines, 16 KiB, Python.
Commit summary: PP-OCRv4 documentation updates; cherry-picks of GH-10217 and GH-10216 into Release/2.7; a missing pyyaml requirement added; TIPC XPU script update; a fix for a data-augmentation error; and the ParseQ recognition method (model, head, configs, and prediction support).
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Code was based on https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
# reference: https://arxiv.org/abs/2207.06966

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
from .self_attention import WrapEncoderForFeature
from .self_attention import WrapEncoder
from collections import OrderedDict
from typing import Optional
import copy
from itertools import permutations


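# The decoder below follows the two-stream design used by ParseQ: a "query"
# stream driven by learned positional queries and a "content" stream driven by
# token embeddings. Keeping the streams separate lets the same layer serve both
# permutation-language-modelling training and left-to-right / cloze decoding,
# with the token ordering expressed purely through attention masks.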
class DecoderLayer(paddle.nn.Layer):
    """A Transformer decoder layer supporting two-stream attention (XLNet).
       This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', layer_norm_eps=1e-05):
        super().__init__()
        self.self_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True)  # paddle.nn.MultiHeadAttention defaults to batch_first behaviour
        self.cross_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True)
        self.linear1 = paddle.nn.Linear(in_features=d_model, out_features=dim_feedforward)
        self.dropout = paddle.nn.Dropout(p=dropout)
        self.linear2 = paddle.nn.Linear(in_features=dim_feedforward, out_features=d_model)
        self.norm1 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
        self.norm2 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
        self.norm_q = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
        self.norm_c = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
        self.dropout1 = paddle.nn.Dropout(p=dropout)
        self.dropout2 = paddle.nn.Dropout(p=dropout)
        self.dropout3 = paddle.nn.Dropout(p=dropout)
        if activation == 'gelu':
            self.activation = paddle.nn.GELU()

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = paddle.nn.functional.gelu
        super().__setstate__(state)

    def forward_stream(self, tgt, tgt_norm, tgt_kv, memory, tgt_mask, tgt_key_padding_mask):
        """Forward pass for a single stream (i.e. content or query).
        tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
        Both tgt_kv and memory are expected to be LayerNorm'd too.
        memory is LayerNorm'd by ViT.
        """
        if tgt_key_padding_mask is not None:
            tgt_mask1 = (tgt_mask != float('-inf'))[None, None, :, :] & (tgt_key_padding_mask[:, None, None, :] == False)
            tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask1)
        else:
            tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask)

        tgt = tgt + self.dropout1(tgt2)
        tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt, sa_weights, ca_weights

    def forward(self, query, content, memory, query_mask=None, content_mask=None, content_key_padding_mask=None, update_content=True):
        query_norm = self.norm_q(query)
        content_norm = self.norm_c(content)
        query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0]
        if update_content:
            content = self.forward_stream(content, content_norm, content_norm, memory, content_mask, content_key_padding_mask)[0]
        return query, content


def get_clones(module, N):
    return paddle.nn.LayerList([copy.deepcopy(module) for i in range(N)])


class Decoder(paddle.nn.Layer):
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm):
        super().__init__()
        self.layers = get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, query, content, memory, query_mask: Optional[paddle.Tensor]=None, content_mask: Optional[paddle.Tensor]=None, content_key_padding_mask: Optional[paddle.Tensor]=None):
        for i, mod in enumerate(self.layers):
            last = i == len(self.layers) - 1
            query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask, update_content=not last)
        query = self.norm(query)
        return query


class TokenEmbedding(paddle.nn.Layer):

    def __init__(self, charset_size: int, embed_dim: int):
        super().__init__()
        self.embedding = paddle.nn.Embedding(num_embeddings=charset_size, embedding_dim=embed_dim)
        self.embed_dim = embed_dim

    def forward(self, tokens: paddle.Tensor):
        return math.sqrt(self.embed_dim) * self.embedding(tokens.astype(paddle.int64))


def trunc_normal_init(param, **kwargs):
    initializer = nn.initializer.TruncatedNormal(**kwargs)
    initializer(param, param.block)


def constant_init(param, **kwargs):
    initializer = nn.initializer.Constant(**kwargs)
    initializer(param, param.block)


def kaiming_normal_init(param, **kwargs):
    initializer = nn.initializer.KaimingNormal(**kwargs)
    initializer(param, param.block)


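# Special token ids are derived from the dictionary size passed in as
# out_channels: index 0 is EOS, index out_channels - 2 is BOS and index
# out_channels - 1 is PAD. The classification head therefore only predicts
# out_channels - 2 classes (the character set plus EOS); BOS and PAD are never
# emitted.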
class ParseQHead(nn.Layer):
    def __init__(self, out_channels, max_text_length, embed_dim, dec_num_heads, dec_mlp_ratio, dec_depth, perm_num, perm_forward, perm_mirrored, decode_ar, refine_iters, dropout, **kwargs):
        super().__init__()

        self.bos_id = out_channels - 2
        self.eos_id = 0
        self.pad_id = out_channels - 1

        self.max_label_length = max_text_length
        self.decode_ar = decode_ar
        self.refine_iters = refine_iters
        decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
        self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=paddle.nn.LayerNorm(normalized_shape=embed_dim))
        self.rng = np.random.default_rng()
        self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
        self.perm_forward = perm_forward
        self.perm_mirrored = perm_mirrored
        self.head = paddle.nn.Linear(in_features=embed_dim, out_features=out_channels - 2)
        self.text_embed = TokenEmbedding(out_channels, embed_dim)
        self.pos_queries = paddle.create_parameter(
            shape=[1, max_text_length + 1, embed_dim],
            dtype=paddle.get_default_dtype(),
            default_initializer=paddle.nn.initializer.Assign(
                paddle.empty(shape=[1, max_text_length + 1, embed_dim])))
        self.pos_queries.stop_gradient = False
        self.dropout = paddle.nn.Dropout(p=dropout)
        self._device = self.parameters()[0].place
        trunc_normal_init(self.pos_queries, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, paddle.nn.Linear):
            trunc_normal_init(m.weight, std=0.02)
            if m.bias is not None:
                constant_init(m.bias, value=0.0)
        elif isinstance(m, paddle.nn.Embedding):
            trunc_normal_init(m.weight, std=0.02)
            if m._padding_idx is not None:
                m.weight.data[m._padding_idx].zero_()
        elif isinstance(m, paddle.nn.Conv2D):
            kaiming_normal_init(m.weight, fan_in=None, nonlinearity='relu')
            if m.bias is not None:
                constant_init(m.bias, value=0.0)
        elif isinstance(m, (paddle.nn.LayerNorm, paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)):
            constant_init(m.weight, value=1.0)
            constant_init(m.bias, value=0.0)

    def no_weight_decay(self):
        # NOTE: assumes an `encoder` attribute is attached, as in the original ParseQ model.
        param_names = {'text_embed.embedding.weight', 'pos_queries'}
        enc_param_names = {('encoder.' + n) for n in self.encoder.no_weight_decay()}
        return param_names.union(enc_param_names)

    def encode(self, img):
        return self.encoder(img)

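    # `decode` wires the two streams together: token embeddings prefixed with
    # the BOS "null context" form the content stream, while the learned
    # positional queries form the query stream whose outputs are later
    # projected to character logits by `self.head`.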
    def decode(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, tgt_query=None, tgt_query_mask=None):
        N, L = tgt.shape
        null_ctx = self.text_embed(tgt[:, :1])
        if L != 1:
            tgt_emb = self.pos_queries[:, :L - 1] + self.text_embed(tgt[:, 1:])
            tgt_emb = self.dropout(paddle.concat(x=[null_ctx, tgt_emb], axis=1))
        else:
            tgt_emb = self.dropout(null_ctx)
        if tgt_query is None:
            tgt_query = self.pos_queries[:, :L].expand(shape=[N, -1, -1])
        tgt_query = self.dropout(tgt_query)
        return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)

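    # Inference path: when decode_ar is set, characters are decoded one step at
    # a time with a causal mask (stopping early once every sequence in the
    # batch has emitted EOS); otherwise all positions are predicted in one
    # parallel pass. Optionally, `refine_iters` rounds of cloze-style
    # refinement re-feed the previous prediction with a relaxed query mask.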
    def forward_test(self, memory, max_length=None):
        testing = max_length is None
        max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
        bs = memory.shape[0]
        num_steps = max_length + 1

        pos_queries = self.pos_queries[:, :num_steps].expand(shape=[bs, -1, -1])
        tgt_mask = query_mask = paddle.triu(x=paddle.full(shape=(num_steps, num_steps), fill_value=float('-inf')), diagonal=1)
        if self.decode_ar:
            tgt_in = paddle.full(shape=(bs, num_steps), fill_value=self.pad_id).astype('int64')
            tgt_in[:, 0] = self.bos_id

            logits = []
            for i in range(num_steps):
                j = i + 1
                tgt_out = self.decode(tgt_in[:, :j], memory, tgt_mask[:j, :j], tgt_query=pos_queries[:, i:j], tgt_query_mask=query_mask[i:j, :j])
                p_i = self.head(tgt_out)
                logits.append(p_i)
                if j < num_steps:
                    tgt_in[:, j] = p_i.squeeze().argmax(axis=-1)
                    if testing and (tgt_in == self.eos_id).any(axis=-1).all():
                        break
            logits = paddle.concat(x=logits, axis=1)
        else:
            tgt_in = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype('int64')
            tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
            logits = self.head(tgt_out)
        if self.refine_iters:
            temp = paddle.triu(x=paddle.ones(shape=[num_steps, num_steps], dtype='bool'), diagonal=2)
            posi = np.where(temp.cpu().numpy())
            query_mask[posi] = 0
            bos = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype('int64')
            for i in range(self.refine_iters):
                tgt_in = paddle.concat(x=[bos, logits[:, :-1].argmax(axis=-1)], axis=1)
                tgt_padding_mask = (tgt_in == self.eos_id).astype(dtype='int32')
                # cumsum runs on CPU and the mask is moved back to the GPU afterwards
                tgt_padding_mask = tgt_padding_mask.cpu()
                tgt_padding_mask = tgt_padding_mask.cumsum(axis=-1) > 0
                tgt_padding_mask = tgt_padding_mask.cuda().astype(dtype='float32') == 1.0
                tgt_out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query=pos_queries, tgt_query_mask=query_mask[:, :tgt_in.shape[1]])
                logits = self.head(tgt_out)

        final_output = {"predict": logits}

        return final_output

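    # Training relies on permutation language modelling: a small set of token
    # orderings is sampled per batch (the canonical left-to-right order is kept
    # when perm_forward is set, and each ordering's mirror is added when
    # perm_mirrored is set). Each permutation is later converted into a pair of
    # attention masks by generate_attn_masks.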
    def gen_tgt_perms(self, tgt):
        """Generate shared permutations for the whole batch.
           The same permutations can be reused for the shorter sequences in the
           batch because the padding mask hides their unused positions.
        """
        max_num_chars = tgt.shape[1] - 2
        if max_num_chars == 1:
            return paddle.arange(end=3).unsqueeze(axis=0)
        perms = [paddle.arange(end=max_num_chars)] if self.perm_forward else []
        max_perms = math.factorial(max_num_chars)
        if self.perm_mirrored:
            max_perms //= 2
        num_gen_perms = min(self.max_gen_perms, max_perms)
        if max_num_chars < 5:
            if max_num_chars == 4 and self.perm_mirrored:
                selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
            else:
                selector = list(range(max_perms))
            perm_pool = paddle.to_tensor(data=list(permutations(range(max_num_chars), max_num_chars)), place=self._device)[selector]
            if self.perm_forward:
                perm_pool = perm_pool[1:]
            perms = paddle.stack(x=perms)
            if len(perm_pool):
                i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(perms), replace=False)
                perms = paddle.concat(x=[perms, perm_pool[i]])
        else:
            perms.extend([paddle.randperm(n=max_num_chars) for _ in range(num_gen_perms - len(perms))])
            perms = paddle.stack(x=perms)
        if self.perm_mirrored:
            comp = perms.flip(axis=-1)
            x = paddle.stack(x=[perms, comp])
            perm_2 = list(range(x.ndim))
            perm_2[0] = 1
            perm_2[1] = 0
            perms = x.transpose(perm=perm_2).reshape((-1, max_num_chars))
        bos_idx = paddle.zeros(shape=(len(perms), 1), dtype=perms.dtype)
        eos_idx = paddle.full(shape=(len(perms), 1), fill_value=max_num_chars + 1, dtype=perms.dtype)
        perms = paddle.concat(x=[bos_idx, perms + 1, eos_idx], axis=1)
        if len(perms) > 1:
            perms[1, 1:] = max_num_chars + 1 - paddle.arange(end=max_num_chars + 1)
        return perms

    def generate_attn_masks(self, perm):
        """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens)
        :param perm: the permutation sequence. i = 0 is always the BOS
        :return: lookahead attention masks
        """
        sz = perm.shape[0]
        mask = paddle.zeros(shape=(sz, sz))
        for i in range(sz):
            query_idx = perm[i].cpu().numpy().tolist()
            masked_keys = perm[i + 1:].cpu().numpy().tolist()
            if len(masked_keys) == 0:
                break
            mask[query_idx, masked_keys] = float('-inf')
        content_mask = mask[:-1, :-1].clone()
        mask[paddle.eye(num_rows=sz).astype('bool')] = float('-inf')
        query_mask = mask[1:, :-1]
        return content_mask, query_mask

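    # Training forward: the same target sequence is decoded once per sampled
    # permutation. The logits of the first permutation are exposed as 'predict'
    # (used for metrics), while 'logits_list' collects every permutation's
    # logits for the permutation language-modelling loss.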
    def forward_train(self, memory, tgt):
        tgt_perms = self.gen_tgt_perms(tgt)
        tgt_in = tgt[:, :-1]
        tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
        logits_list = []
        final_out = {}
        for i, perm in enumerate(tgt_perms):
            tgt_mask, query_mask = self.generate_attn_masks(perm)
            out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask)
            logits = self.head(out)
            if i == 0:
                final_out['predict'] = logits
            logits = logits.flatten(stop_axis=1)
            logits_list.append(logits)

        final_out['logits_list'] = logits_list
        final_out['pad_id'] = self.pad_id
        final_out['eos_id'] = self.eos_id

        return final_out

    def forward(self, feat, targets=None):
        # feat : B, N, C
        # targets : labels, labels_len

        if self.training:
            label = targets[0]  # label
            label_len = targets[1]
            max_step = paddle.max(label_len).cpu().numpy()[0] + 2
            crop_label = label[:, :max_step]
            final_out = self.forward_train(feat, crop_label)
        else:
            final_out = self.forward_test(feat)

        return final_out
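

# Minimal smoke-test sketch (not part of the upstream file). The hyper-parameter
# values below are illustrative placeholders, not the settings shipped with the
# released ParseQ configs (e.g. rec_vit_parseq.yml). Run it as a module inside
# the PaddleOCR package so the relative import above resolves.
if __name__ == '__main__':
    head = ParseQHead(
        out_channels=97,        # illustrative: charset size plus EOS/BOS/PAD ids
        max_text_length=25,
        embed_dim=384,
        dec_num_heads=12,
        dec_mlp_ratio=4,
        dec_depth=1,
        perm_num=6,
        perm_forward=True,
        perm_mirrored=True,
        decode_ar=True,
        refine_iters=0,         # skip the GPU-only refinement path in this sketch
        dropout=0.1)
    head.eval()
    feat = paddle.rand([2, 197, 384])   # [B, N, C] features from the vision backbone
    out = head(feat)
    print(out['predict'].shape)         # e.g. [2, 26, 95] with untrained weights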