mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-12-27 15:08:17 +00:00
add satrn (#8433)
* add satrn * fix the satrn export problem * standardize the satrn config file * remove SATRNRecResizeImg --------- Co-authored-by: zhiminzhang0830 <zhangzhimin04@baidu.com>
This commit is contained in:
parent
3ded6010e2
commit
30201ef954
117
configs/rec/rec_satrn.yml
Normal file
@@ -0,0 +1,117 @@
Global:
  use_gpu: true
  epoch_num: 5
  log_smooth_window: 20
  print_batch_step: 50
  save_model_dir: ./output/rec/rec_satrn/
  save_epoch_step: 1
  # evaluation is run every 5000 iterations
  eval_batch_step: [0, 5000]
  cal_metric_during_train: False
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img:
  # for data or label process
  character_dict_path: ppocr/utils/dict90.txt
  max_text_length: 25
  infer_mode: False
  use_space_char: False
  rm_symbol: True
  save_res_path: ./output/rec/predicts_satrn.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    decay_epochs: [3, 4]
    values: [0.0003, 0.00003, 0.000003]
  regularizer:
    name: 'L2'
    factor: 0

Architecture:
  model_type: rec
  algorithm: SATRN
  Backbone:
    name: ShallowCNN
    in_channels: 3
    hidden_dim: 256
  Head:
    name: SATRNHead
    enc_cfg:
      n_layers: 6
      n_head: 8
      d_k: 32
      d_v: 32
      d_model: 256
      n_position: 100
      d_inner: 1024
      dropout: 0.1
    dec_cfg:
      n_layers: 6
      d_embedding: 256
      n_head: 8
      d_model: 256
      d_inner: 1024
      d_k: 32
      d_v: 32
      max_seq_len: 25
      start_idx: 91

Loss:
  name: SATRNLoss

PostProcess:
  name: SATRNLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/training/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - SATRNLabelEncode: # Class handling label
      - SVTRRecResizeImg:
          image_shape: [3, 32, 100]
          padding: False
      - KeepKeys:
          keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 128
    drop_last: True
    num_workers: 8
    use_shared_memory: False

Eval:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/evaluation/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - SATRNLabelEncode: # Class handling label
      - SVTRRecResizeImg:
          image_shape: [3, 32, 100]
          padding: False
      - KeepKeys:
          keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order

  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 128
    num_workers: 4
    use_shared_memory: False
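For context: the Piecewise schedule above holds the base learning rate of 0.0003 through epochs 0-2, then decays it tenfold at each boundary in decay_epochs (epoch 3 and epoch 4). A minimal sketch of that lookup in plain Python (illustrative only, not PaddleOCR's scheduler):

import bisect

def piecewise_lr(epoch, decay_epochs=(3, 4),
                 values=(0.0003, 0.00003, 0.000003)):
    # bisect_right counts how many decay boundaries have passed,
    # which indexes directly into values
    return values[bisect.bisect_right(decay_epochs, epoch)]

assert piecewise_lr(0) == 0.0003    # epochs 0-2: base LR
assert piecewise_lr(3) == 0.00003   # first decay
assert piecewise_lr(4) == 0.000003  # second decay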
ppocr/data/imaug/label_ops.py
@@ -886,6 +886,62 @@ class SARLabelEncode(BaseRecLabelEncode):
        return [self.padding_idx]


class SATRNLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 lower=False,
                 **kwargs):
        super(SATRNLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)
        self.lower = lower

    def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1

        return dict_character

    def encode(self, text):
        if self.lower:
            text = text.lower()
        text_list = []
        for char in text:
            text_list.append(self.dict.get(char, self.unknown_idx))
        if len(text_list) == 0:
            return None
        return text_list

    def __call__(self, data):
        text = data['label']
        text = self.encode(text)
        if text is None:
            return None
        data['length'] = np.array(len(text))
        target = [self.start_idx] + text + [self.end_idx]
        padded_text = [self.padding_idx for _ in range(self.max_text_len)]
        if len(target) > self.max_text_len:
            padded_text = target[:self.max_text_len]
        else:
            padded_text[:len(target)] = target
        data['label'] = np.array(padded_text)
        return data

    def get_ignored_tokens(self):
        return [self.padding_idx]


class PRENLabelEncode(BaseRecLabelEncode):
    def __init__(self,
                 max_text_length,
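To make the encoding above concrete: add_special_char appends <UKN>, <BOS/EOS> and <PAD> after the 90 characters of dict90.txt, so start_idx == end_idx == 91 and padding_idx == 92, which matches start_idx: 91 in the config and the default ignore_index of 92 in the loss below. A standalone sketch with a toy three-character dictionary (indices illustrative):

# toy dictionary: 3 characters plus the special tokens, appended in the
# same order as add_special_char: <UKN>, then <BOS/EOS>, then <PAD>
chars = ['a', 'b', 'c', '<UKN>', '<BOS/EOS>', '<PAD>']
char2idx = {c: i for i, c in enumerate(chars)}
unknown_idx, start_idx, end_idx, padding_idx = 3, 4, 4, 5
max_text_len = 8

text = "ab?"  # '?' is out of vocabulary -> <UKN>
text_ids = [char2idx.get(ch, unknown_idx) for ch in text]  # [0, 1, 3]
target = [start_idx] + text_ids + [end_idx]                # [4, 0, 1, 3, 4]
padded = target[:max_text_len] + \
    [padding_idx] * max(0, max_text_len - len(target))
print(padded)  # [4, 0, 1, 3, 4, 5, 5, 5]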
ppocr/losses/__init__.py
@@ -41,6 +41,7 @@ from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss
from .rec_rfl_loss import RFLLoss
from .rec_can_loss import CANLoss
from .rec_satrn_loss import SATRNLoss

# cls loss
from .cls_loss import ClsLoss

@@ -73,7 +74,8 @@ def build_loss(config):
        'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
        'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
-        'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss'
+        'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss',
+        'SATRNLoss'
    ]
    config = copy.deepcopy(config)
    module_name = config.pop('name')
46
ppocr/losses/rec_satrn_loss.py
Normal file
@@ -0,0 +1,46 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/module_losses/ce_module_loss.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn


class SATRNLoss(nn.Layer):
    def __init__(self, **kwargs):
        super(SATRNLoss, self).__init__()
        ignore_index = kwargs.get('ignore_index', 92)  # 6626
        self.loss_func = paddle.nn.loss.CrossEntropyLoss(
            reduction="none", ignore_index=ignore_index)

    def forward(self, predicts, batch):
        # ignore last index of outputs to be in same seq_len with targets
        predict = predicts[:, :-1, :]
        # ignore first index of target in loss calculation
        label = batch[1].astype("int64")[:, 1:]
        batch_size, num_steps, num_classes = predict.shape[0], \
            predict.shape[1], predict.shape[2]
        assert len(label.shape) == len(list(predict.shape)) - 1, \
            "The target's shape and inputs's shape is [N, d] and [N, num_steps]"

        inputs = paddle.reshape(predict, [-1, num_classes])
        targets = paddle.reshape(label, [-1])
        loss = self.loss_func(inputs, targets)
        return {'loss': loss.mean()}
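The forward pass applies the standard teacher-forcing shift: the prediction at step t is scored against the target token at t+1, and <PAD> positions (index 92, see above) contribute nothing via ignore_index. A numpy-only sketch of that alignment (shapes assumed from the config: labels [N, 25], predictions [N, 25, C]):

import numpy as np

N, T, C = 2, 25, 93        # batch, padded label length, classes
padding_idx = 92           # the default ignore_index above

predicts = np.random.rand(N, T, C)     # one prediction per decoder step
labels = np.full((N, T), padding_idx)  # e.g. <BOS/EOS> a b <BOS/EOS> <PAD> ...

predict = predicts[:, :-1, :]  # step t predicts ...
target = labels[:, 1:]         # ... the token at position t + 1
assert predict.shape[:2] == target.shape == (N, T - 1)

valid = target != padding_idx  # ignore_index positions are not scored
print(int(valid.sum()), 'positions are scored')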
ppocr/modeling/backbones/__init__.py
@@ -44,11 +44,12 @@ def build_backbone(config, model_type):
        from .rec_vitstr import ViTSTR
        from .rec_resnet_rfl import ResNetRFL
        from .rec_densenet import DenseNet
        from .rec_shallow_cnn import ShallowCNN
        support_dict = [
            'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
            'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
-            'DenseNet'
+            'DenseNet', 'ShallowCNN'
        ]
    elif model_type == 'e2e':
        from .e2e_resnet_vd_pg import ResNet
87
ppocr/modeling/backbones/rec_shallow_cnn.py
Normal file
@@ -0,0 +1,87 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/backbones/shallow_cnn.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import MaxPool2D
from paddle.nn.initializer import KaimingNormal, Uniform, Constant


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 padding,
                 num_groups=1):
        super(ConvBNLayer, self).__init__()

        self.conv = nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(initializer=KaimingNormal()),
            bias_attr=False)

        self.bn = nn.BatchNorm2D(
            num_filters,
            weight_attr=ParamAttr(initializer=Uniform(0, 1)),
            bias_attr=ParamAttr(initializer=Constant(0)))
        self.relu = nn.ReLU()

    def forward(self, inputs):
        y = self.conv(inputs)
        y = self.bn(y)
        y = self.relu(y)
        return y


class ShallowCNN(nn.Layer):
    def __init__(self, in_channels=1, hidden_dim=512):
        super().__init__()
        assert isinstance(in_channels, int)
        assert isinstance(hidden_dim, int)

        self.conv1 = ConvBNLayer(
            in_channels, 3, hidden_dim // 2, stride=1, padding=1)
        self.conv2 = ConvBNLayer(
            hidden_dim // 2, 3, hidden_dim, stride=1, padding=1)
        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.out_channels = hidden_dim

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = self.pool(x)

        return x
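With the settings from rec_satrn.yml (in_channels: 3, hidden_dim: 256), each conv+pool stage halves the spatial resolution, so a 32x100 crop leaves the backbone as a 256-channel 8x25 map that the encoder then flattens into 200 tokens. A quick shape check (assumes paddle is installed and the script runs from the repo root):

import paddle
from ppocr.modeling.backbones.rec_shallow_cnn import ShallowCNN

backbone = ShallowCNN(in_channels=3, hidden_dim=256)
x = paddle.rand([1, 3, 32, 100])  # [N, C, H, W] as in the config
feat = backbone(x)
print(feat.shape)  # [1, 256, 8, 25]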
ppocr/modeling/heads/__init__.py
@@ -40,6 +40,7 @@ def build_head(config):
    from .rec_visionlan_head import VLHead
    from .rec_rfl_head import RFLHead
    from .rec_can_head import CANHead
    from .rec_satrn_head import SATRNHead

    # cls head
    from .cls_head import ClsHead

@@ -56,7 +57,7 @@ def build_head(config):
        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
        'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
-        'DRRGHead', 'CANHead'
+        'DRRGHead', 'CANHead', 'SATRNHead'
    ]

    if config['name'] == 'DRRGHead':
568
ppocr/modeling/heads/rec_satrn_head.py
Normal file
@@ -0,0 +1,568 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/encoders/satrn_encoder.py
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/decoders/nrtr_decoder.py
"""

import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr, reshape, transpose
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import KaimingNormal, Uniform, Constant


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 padding,
                 num_groups=1):
        super(ConvBNLayer, self).__init__()

        self.conv = nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            bias_attr=False)

        self.bn = nn.BatchNorm2D(
            num_filters,
            weight_attr=ParamAttr(initializer=Constant(1)),
            bias_attr=ParamAttr(initializer=Constant(0)))
        self.relu = nn.ReLU()

    def forward(self, inputs):
        y = self.conv(inputs)
        y = self.bn(y)
        y = self.relu(y)
        return y


class SATRNEncoderLayer(nn.Layer):
    def __init__(self,
                 d_model=512,
                 d_inner=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = LocalityAwareFeedforward(
            d_model, d_inner, dropout=dropout)

    def forward(self, x, h, w, mask=None):
        n, hw, c = x.shape
        residual = x
        x = self.norm1(x)
        x = residual + self.attn(x, x, x, mask)
        residual = x
        x = self.norm2(x)
        x = x.transpose([0, 2, 1]).reshape([n, c, h, w])
        x = self.feed_forward(x)
        x = x.reshape([n, c, hw]).transpose([0, 2, 1])
        x = residual + x
        return x


class LocalityAwareFeedforward(nn.Layer):
    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.conv1 = ConvBNLayer(d_in, 1, d_hid, stride=1, padding=0)

        self.depthwise_conv = ConvBNLayer(
            d_hid, 3, d_hid, stride=1, padding=1, num_groups=d_hid)

        self.conv2 = ConvBNLayer(d_hid, 1, d_in, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.depthwise_conv(x)
        x = self.conv2(x)

        return x


class Adaptive2DPositionalEncoding(nn.Layer):
    def __init__(self, d_hid=512, n_height=100, n_width=100, dropout=0.1):
        super().__init__()

        h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
        h_position_encoder = h_position_encoder.transpose([1, 0])
        h_position_encoder = h_position_encoder.reshape([1, d_hid, n_height, 1])

        w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
        w_position_encoder = w_position_encoder.transpose([1, 0])
        w_position_encoder = w_position_encoder.reshape([1, d_hid, 1, n_width])

        self.register_buffer('h_position_encoder', h_position_encoder)
        self.register_buffer('w_position_encoder', w_position_encoder)

        self.h_scale = self.scale_factor_generate(d_hid)
        self.w_scale = self.scale_factor_generate(d_hid)
        self.pool = nn.AdaptiveAvgPool2D(1)
        self.dropout = nn.Dropout(p=dropout)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table."""
        denominator = paddle.to_tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.reshape([1, -1])
        pos_tensor = paddle.cast(
            paddle.arange(n_position).unsqueeze(-1), 'float32')
        sinusoid_table = pos_tensor * denominator
        sinusoid_table[:, 0::2] = paddle.sin(sinusoid_table[:, 0::2])
        sinusoid_table[:, 1::2] = paddle.cos(sinusoid_table[:, 1::2])

        return sinusoid_table

    def scale_factor_generate(self, d_hid):
        scale_factor = nn.Sequential(
            nn.Conv2D(d_hid, d_hid, 1),
            nn.ReLU(), nn.Conv2D(d_hid, d_hid, 1), nn.Sigmoid())

        return scale_factor

    def forward(self, x):
        b, c, h, w = x.shape

        avg_pool = self.pool(x)

        h_pos_encoding = \
            self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
        w_pos_encoding = \
            self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]

        out = x + h_pos_encoding + w_pos_encoding

        out = self.dropout(out)

        return out


class ScaledDotProductAttention(nn.Layer):
    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        def masked_fill(x, mask, value):
            y = paddle.full(x.shape, value, x.dtype)
            return paddle.where(mask, y, x)

        attn = paddle.matmul(q / self.temperature, k.transpose([0, 1, 3, 2]))
        if mask is not None:
            attn = masked_fill(attn, mask == 0, -1e9)
            # attn = attn.masked_fill(mask == 0, float('-inf'))
            # attn += mask

        attn = self.dropout(F.softmax(attn, axis=-1))
        output = paddle.matmul(attn, v)

        return output, attn


class MultiHeadAttention(nn.Layer):
    def __init__(self,
                 n_head=8,
                 d_model=512,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.dim_k = n_head * d_k
        self.dim_v = n_head * d_v

        self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias_attr=qkv_bias)
        self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias_attr=qkv_bias)
        self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias_attr=qkv_bias)

        self.attention = ScaledDotProductAttention(d_k**0.5, dropout)

        self.fc = nn.Linear(self.dim_v, d_model, bias_attr=qkv_bias)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch_size, len_q, _ = q.shape
        _, len_k, _ = k.shape

        q = self.linear_q(q).reshape([batch_size, len_q, self.n_head, self.d_k])
        k = self.linear_k(k).reshape([batch_size, len_k, self.n_head, self.d_k])
        v = self.linear_v(v).reshape([batch_size, len_k, self.n_head, self.d_v])

        q, k, v = q.transpose([0, 2, 1, 3]), k.transpose(
            [0, 2, 1, 3]), v.transpose([0, 2, 1, 3])

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            elif mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(1)

        attn_out, _ = self.attention(q, k, v, mask=mask)

        attn_out = attn_out.transpose([0, 2, 1, 3]).reshape(
            [batch_size, len_q, self.dim_v])

        attn_out = self.fc(attn_out)
        attn_out = self.proj_drop(attn_out)

        return attn_out


class SATRNEncoder(nn.Layer):
    def __init__(self,
                 n_layers=12,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 n_position=100,
                 d_inner=256,
                 dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.position_enc = Adaptive2DPositionalEncoding(
            d_hid=d_model,
            n_height=n_position,
            n_width=n_position,
            dropout=dropout)
        self.layer_stack = nn.LayerList([
            SATRNEncoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, feat, valid_ratios=None):
        """
        Args:
            feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
            valid_ratios (list[float]): The valid width ratio of each image
                in the batch (the ``valid_ratio`` meta of the inputs).

        Returns:
            Tensor: A tensor of shape :math:`(N, T, D_m)`.
        """
        if valid_ratios is None:
            valid_ratios = [1.0 for _ in range(feat.shape[0])]
        feat = self.position_enc(feat)
        n, c, h, w = feat.shape

        mask = paddle.zeros((n, h, w))
        for i, valid_ratio in enumerate(valid_ratios):
            valid_width = min(w, math.ceil(w * valid_ratio))
            mask[i, :, :valid_width] = 1

        mask = mask.reshape([n, h * w])
        feat = feat.reshape([n, c, h * w])

        output = feat.transpose([0, 2, 1])
        for enc_layer in self.layer_stack:
            output = enc_layer(output, h, w, mask)
        output = self.layer_norm(output)

        return output


class PositionwiseFeedForward(nn.Layer):
    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.w_1(x)
        x = self.act(x)
        x = self.w_2(x)
        x = self.dropout(x)

        return x


class PositionalEncoding(nn.Layer):
    def __init__(self, d_hid=512, n_position=200, dropout=0):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Not a parameter
        # Position table of shape (1, n_position, d_hid)
        self.register_buffer(
            'position_table',
            self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table."""
        denominator = paddle.to_tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.reshape([1, -1])
        pos_tensor = paddle.cast(
            paddle.arange(n_position).unsqueeze(-1), 'float32')
        sinusoid_table = pos_tensor * denominator
        sinusoid_table[:, 0::2] = paddle.sin(sinusoid_table[:, 0::2])
        sinusoid_table[:, 1::2] = paddle.cos(sinusoid_table[:, 1::2])

        return sinusoid_table.unsqueeze(0)

    def forward(self, x):
        x = x + self.position_table[:, :x.shape[1]].clone().detach()
        return self.dropout(x)


class TFDecoderLayer(nn.Layer):
    def __init__(self,
                 d_model=512,
                 d_inner=256,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False,
                 operation_order=None):
        super().__init__()

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.self_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)

        self.enc_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)

        self.mlp = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

        self.operation_order = operation_order
        if self.operation_order is None:
            self.operation_order = ('norm', 'self_attn', 'norm', 'enc_dec_attn',
                                    'norm', 'ffn')
        assert self.operation_order in [
            ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn'),
            ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm')
        ]

    def forward(self,
                dec_input,
                enc_output,
                self_attn_mask=None,
                dec_enc_attn_mask=None):
        if self.operation_order == ('self_attn', 'norm', 'enc_dec_attn', 'norm',
                                    'ffn', 'norm'):
            dec_attn_out = self.self_attn(dec_input, dec_input, dec_input,
                                          self_attn_mask)
            dec_attn_out += dec_input
            dec_attn_out = self.norm1(dec_attn_out)

            enc_dec_attn_out = self.enc_attn(dec_attn_out, enc_output,
                                             enc_output, dec_enc_attn_mask)
            enc_dec_attn_out += dec_attn_out
            enc_dec_attn_out = self.norm2(enc_dec_attn_out)

            mlp_out = self.mlp(enc_dec_attn_out)
            mlp_out += enc_dec_attn_out
            mlp_out = self.norm3(mlp_out)
        elif self.operation_order == ('norm', 'self_attn', 'norm',
                                      'enc_dec_attn', 'norm', 'ffn'):
            dec_input_norm = self.norm1(dec_input)
            dec_attn_out = self.self_attn(dec_input_norm, dec_input_norm,
                                          dec_input_norm, self_attn_mask)
            dec_attn_out += dec_input

            enc_dec_attn_in = self.norm2(dec_attn_out)
            enc_dec_attn_out = self.enc_attn(enc_dec_attn_in, enc_output,
                                             enc_output, dec_enc_attn_mask)
            enc_dec_attn_out += dec_attn_out

            mlp_out = self.mlp(self.norm3(enc_dec_attn_out))
            mlp_out += enc_dec_attn_out

        return mlp_out


class SATRNDecoder(nn.Layer):
    def __init__(self,
                 n_layers=6,
                 d_embedding=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 n_position=200,
                 dropout=0.1,
                 num_classes=93,
                 max_seq_len=40,
                 start_idx=1,
                 padding_idx=92):
        super().__init__()

        self.padding_idx = padding_idx
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len

        self.trg_word_emb = nn.Embedding(
            num_classes, d_embedding, padding_idx=padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)

        self.layer_stack = nn.LayerList([
            TFDecoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)

        pred_num_class = num_classes - 1  # ignore padding_idx
        self.classifier = nn.Linear(d_model, pred_num_class)

    @staticmethod
    def get_pad_mask(seq, pad_idx):
        return (seq != pad_idx).unsqueeze(-2)

    @staticmethod
    def get_subsequent_mask(seq):
        """For masking out the subsequent info."""
        len_s = seq.shape[1]
        subsequent_mask = 1 - paddle.triu(
            paddle.ones((len_s, len_s)), diagonal=1)
        subsequent_mask = paddle.cast(subsequent_mask.unsqueeze(0), 'bool')

        return subsequent_mask

    def _attention(self, trg_seq, src, src_mask=None):
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        tgt = self.dropout(trg_pos_encoded)

        trg_mask = self.get_pad_mask(
            trg_seq,
            pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)
        output = tgt
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        output = self.layer_norm(output)

        return output

    def _get_mask(self, logit, valid_ratios):
        N, T, _ = logit.shape
        mask = None
        if valid_ratios is not None:
            mask = paddle.zeros((N, T))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(T, math.ceil(T * valid_ratio))
                mask[i, :valid_width] = 1

        return mask

    def forward_train(self, feat, out_enc, targets, valid_ratio):
        src_mask = self._get_mask(out_enc, valid_ratio)
        attn_output = self._attention(targets, out_enc, src_mask=src_mask)
        outputs = self.classifier(attn_output)

        return outputs

    def forward_test(self, feat, out_enc, valid_ratio):
        src_mask = self._get_mask(out_enc, valid_ratio)
        N = out_enc.shape[0]
        init_target_seq = paddle.full(
            (N, self.max_seq_len + 1), self.padding_idx, dtype='int64')
        # bsz * seq_len
        init_target_seq[:, 0] = self.start_idx

        outputs = []
        for step in range(0, paddle.to_tensor(self.max_seq_len)):
            decoder_output = self._attention(
                init_target_seq, out_enc, src_mask=src_mask)
            # bsz * seq_len * C
            step_result = F.softmax(
                self.classifier(decoder_output[:, step, :]), axis=-1)
            # bsz * num_classes
            outputs.append(step_result)
            step_max_index = paddle.argmax(step_result, axis=-1)
            init_target_seq[:, step + 1] = step_max_index

        outputs = paddle.stack(outputs, axis=1)

        return outputs

    def forward(self, feat, out_enc, targets=None, valid_ratio=None):
        if self.training:
            return self.forward_train(feat, out_enc, targets, valid_ratio)
        else:
            return self.forward_test(feat, out_enc, valid_ratio)


class SATRNHead(nn.Layer):
    def __init__(self, enc_cfg, dec_cfg, **kwargs):
        super(SATRNHead, self).__init__()

        # encoder module
        self.encoder = SATRNEncoder(**enc_cfg)

        # decoder module
        self.decoder = SATRNDecoder(**dec_cfg)

    def forward(self, feat, targets=None):
        if targets is not None:
            targets, valid_ratio = targets
        else:
            targets, valid_ratio = None, None
        holistic_feat = self.encoder(feat, valid_ratio)  # bsz c

        final_out = self.decoder(feat, holistic_feat, targets, valid_ratio)

        return final_out
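Both _get_sinusoid_encoding_table methods above build the standard Transformer table, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)); the adaptive 2D variant just builds one table per axis and rescales each with a learned, input-conditioned gate. A numpy equivalent for sanity checking:

import numpy as np

def sinusoid_table(n_position, d_hid):
    # consecutive (sin, cos) channel pairs share one frequency: j // 2
    denom = np.array([1.0 / np.power(10000, 2 * (j // 2) / d_hid)
                      for j in range(d_hid)])
    table = np.arange(n_position, dtype='float32')[:, None] * denom[None, :]
    table[:, 0::2] = np.sin(table[:, 0::2])  # even channels
    table[:, 1::2] = np.cos(table[:, 1::2])  # odd channels
    return table  # [n_position, d_hid]

print(sinusoid_table(100, 256).shape)  # (100, 256)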
ppocr/postprocess/__init__.py
@@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
    DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
    SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
-    SPINLabelDecode, VLLabelDecode, RFLLabelDecode
+    SPINLabelDecode, VLLabelDecode, RFLLabelDecode, SATRNLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess

@@ -52,7 +52,8 @@ def build_post_process(config, global_config=None):
        'TableMasterLabelDecode', 'SPINLabelDecode',
        'DistillationSerPostProcess', 'DistillationRePostProcess',
        'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
-        'RFLLabelDecode', 'DRRGPostprocess', 'CANLabelDecode'
+        'RFLLabelDecode', 'DRRGPostprocess', 'CANLabelDecode',
+        'SATRNLabelDecode'
    ]

    if config['name'] == 'PSEPostProcess':
ppocr/postprocess/rec_postprocess.py
@@ -568,6 +568,82 @@ class SARLabelDecode(BaseRecLabelDecode):
        return [self.padding_idx]


class SATRNLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SATRNLabelDecode, self).__init__(character_dict_path,
                                               use_space_char)

        self.rm_symbol = kwargs.get('rm_symbol', False)

    def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()

        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(self.end_idx):
                    if text_prob is None and idx == 0:
                        continue
                    else:
                        break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                            batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            if self.rm_symbol:
                comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
                text = text.lower()
                text = comp.sub('', text)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        return [self.padding_idx]


class DistillationSARLabelDecode(SARLabelDecode):
    """
    Convert
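Decoding mirrors the special tokens added on the encode side: <PAD> is an ignored token, reading stops at the first <BOS/EOS>, and with rm_symbol: True (set both in the config and in predict_rec.py below) the text is lower-cased and stripped to alphanumerics and CJK. A pure-Python sketch using the same toy dictionary as the encode example:

import re

chars = ['a', 'b', 'c', '<UKN>', '<BOS/EOS>', '<PAD>']
end_idx, padding_idx = 4, 5

pred_ids = [0, 1, 2, 4, 5, 5]  # 'abc', then <BOS/EOS> ends the sequence
out = []
for idx in pred_ids:
    if idx == padding_idx:  # ignored token
        continue
    if idx == end_idx:      # stop at the end symbol
        break
    out.append(chars[idx])
text = ''.join(out)

# rm_symbol keeps only [A-Za-z0-9] and CJK, lower-cased
text = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]').sub('', text.lower())
print(text)  # abc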
tools/export_model.py
@@ -105,6 +105,12 @@ def export_single_model(model,
                shape=[None, 1, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == 'SATRN':
+        other_shape = [
+            paddle.static.InputSpec(
+                shape=[None, 3, 32, 100], dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "VisionLAN":
        other_shape = [
            paddle.static.InputSpec(
tools/infer/predict_rec.py
@@ -106,6 +106,13 @@ class TextRecognizer(object):
                "character_dict_path": None,
                "use_space_char": args.use_space_char
            }
+        elif self.rec_algorithm == "SATRN":
+            postprocess_params = {
+                'name': 'SATRNLabelDecode',
+                "character_dict_path": args.rec_char_dict_path,
+                "use_space_char": args.use_space_char,
+                "rm_symbol": True
+            }
        elif self.rec_algorithm == "PREN":
            postprocess_params = {'name': 'PRENLabelDecode'}
        elif self.rec_algorithm == "CAN":

@@ -429,7 +436,7 @@ class TextRecognizer(object):
                gsrm_slf_attn_bias1_list.append(norm_img[3])
                gsrm_slf_attn_bias2_list.append(norm_img[4])
                norm_img_batch.append(norm_img[0])
-            elif self.rec_algorithm == "SVTR":
+            elif self.rec_algorithm in ["SVTR", "SATRN"]:
                norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
                                                     self.rec_image_shape)
                norm_img = norm_img[np.newaxis, :]
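SATRN is routed through the SVTR resize path because both models take a fixed 3x32x100 input with no aspect-preserving padding (padding: False in the config). That preprocessing amounts to roughly the following (a sketch of resize_norm_img_svtr's behavior, not a verbatim copy):

import cv2
import numpy as np

def resize_norm_img(img, image_shape=(3, 32, 100)):
    imgC, imgH, imgW = image_shape
    resized = cv2.resize(img, (imgW, imgH))  # stretch to 100x32, no padding
    resized = resized.astype('float32').transpose((2, 0, 1)) / 255.0
    return (resized - 0.5) / 0.5             # normalize to [-1, 1]

dummy = np.zeros((48, 320, 3), dtype=np.uint8)
print(resize_norm_img(dummy).shape)  # (3, 32, 100)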
tools/program.py
@@ -220,7 +220,7 @@ def train(config,
    use_srn = config['Architecture']['algorithm'] == "SRN"
    extra_input_models = [
        "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
-        "RobustScanner", "RFL", 'DRRG'
+        "RobustScanner", "RFL", 'DRRG', 'SATRN'
    ]
    extra_input = False
    if config['Architecture']['algorithm'] == 'Distillation':

@@ -643,7 +643,7 @@ def preprocess(is_train=False):
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
        'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN',
-        'Telescope'
+        'Telescope', 'SATRN'
    ]

    if use_xpu:
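Registering 'SATRN' in extra_input_models is what feeds the encoded label and valid_ratio into the head during teacher-forced training: for models on that list the train loop passes the remaining batch fields to forward. Schematically (a simplification of the loop in tools/program.py, not its exact code):

# inside the train loop, per iteration
images = batch[0]
if extra_input:
    # for SATRN, batch[1] is the encoded label, batch[2] the valid_ratio
    preds = model(images, data=batch[1:])
else:
    preds = model(images)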