mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-10-09 15:06:34 +00:00
Merge pull request #6842 from smilelite/robustscanner_branch
添加robustscanner(第三次)
This commit is contained in:
commit
a485740fa7
109
configs/rec/rec_r31_robustscanner.yml
Normal file
109
configs/rec/rec_r31_robustscanner.yml
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 5
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 20
|
||||||
|
save_model_dir: ./output/rec/rec_r31_robustscanner/
|
||||||
|
save_epoch_step: 1
|
||||||
|
# evaluation is run every 2000 iterations
|
||||||
|
eval_batch_step: [0, 2000]
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: ./inference/rec_inference
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path: ppocr/utils/dict90.txt
|
||||||
|
max_text_length: &max_text_length 40
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
rm_symbol: True
|
||||||
|
save_res_path: ./output/rec/predicts_robustscanner.txt
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
name: Piecewise
|
||||||
|
decay_epochs: [3, 4]
|
||||||
|
values: [0.001, 0.0001, 0.00001]
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: RobustScanner
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNet31
|
||||||
|
init_type: KaimingNormal
|
||||||
|
Head:
|
||||||
|
name: RobustScannerHead
|
||||||
|
enc_outchannles: 128
|
||||||
|
hybrid_dec_rnn_layers: 2
|
||||||
|
hybrid_dec_dropout: 0
|
||||||
|
position_dec_rnn_layers: 2
|
||||||
|
start_idx: 91
|
||||||
|
mask: True
|
||||||
|
padding_idx: 92
|
||||||
|
encode_value: False
|
||||||
|
max_text_length: *max_text_length
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: SARLoss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: SARLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: RecMetric
|
||||||
|
is_filter: True
|
||||||
|
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- SARLabelEncode: # Class handling label
|
||||||
|
- RobustScannerRecResizeImg:
|
||||||
|
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
|
||||||
|
width_downsample_ratio: 0.25
|
||||||
|
max_text_length: *max_text_length
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 64
|
||||||
|
drop_last: True
|
||||||
|
num_workers: 8
|
||||||
|
use_shared_memory: False
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/data_lmdb_release/evaluation/
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- SARLabelEncode: # Class handling label
|
||||||
|
- RobustScannerRecResizeImg:
|
||||||
|
image_shape: [3, 48, 48, 160]
|
||||||
|
max_text_length: *max_text_length
|
||||||
|
width_downsample_ratio: 0.25
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 64
|
||||||
|
num_workers: 4
|
||||||
|
use_shared_memory: False
|
||||||
|
|
@ -72,6 +72,7 @@
|
|||||||
- [x] [ABINet](./algorithm_rec_abinet.md)
|
- [x] [ABINet](./algorithm_rec_abinet.md)
|
||||||
- [x] [VisionLAN](./algorithm_rec_visionlan.md)
|
- [x] [VisionLAN](./algorithm_rec_visionlan.md)
|
||||||
- [x] [SPIN](./algorithm_rec_spin.md)
|
- [x] [SPIN](./algorithm_rec_spin.md)
|
||||||
|
- [x] [RobustScanner](./algorithm_rec_robustscanner.md)
|
||||||
|
|
||||||
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||||
|
|
||||||
@ -94,6 +95,7 @@
|
|||||||
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|
||||||
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|
||||||
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|
||||||
|
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
|
||||||
|
|
||||||
|
|
||||||
<a name="2"></a>
|
<a name="2"></a>
|
||||||
|
113
doc/doc_ch/algorithm_rec_robustscanner.md
Normal file
113
doc/doc_ch/algorithm_rec_robustscanner.md
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
# RobustScanner
|
||||||
|
|
||||||
|
- [1. 算法简介](#1)
|
||||||
|
- [2. 环境配置](#2)
|
||||||
|
- [3. 模型训练、评估、预测](#3)
|
||||||
|
- [3.1 训练](#3-1)
|
||||||
|
- [3.2 评估](#3-2)
|
||||||
|
- [3.3 预测](#3-3)
|
||||||
|
- [4. 推理部署](#4)
|
||||||
|
- [4.1 Python推理](#4-1)
|
||||||
|
- [4.2 C++推理](#4-2)
|
||||||
|
- [4.3 Serving服务化部署](#4-3)
|
||||||
|
- [4.4 更多推理部署](#4-4)
|
||||||
|
- [5. FAQ](#5)
|
||||||
|
|
||||||
|
<a name="1"></a>
|
||||||
|
## 1. 算法简介
|
||||||
|
|
||||||
|
论文信息:
|
||||||
|
> [RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition](https://arxiv.org/pdf/2007.07542.pdf)
|
||||||
|
> Xiaoyu Yue, Zhanghui Kuang, Chenhao Lin, Hongbin Sun, Wayne
|
||||||
|
Zhang
|
||||||
|
> ECCV, 2020
|
||||||
|
|
||||||
|
使用MJSynth和SynthText两个合成文字识别数据集训练,在IIIT, SVT, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
|
||||||
|
|
||||||
|
|模型|骨干网络|配置文件|Acc|下载链接|
|
||||||
|
| --- | --- | --- | --- | --- |
|
||||||
|
|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
|
||||||
|
|
||||||
|
注:除了使用MJSynth和SynthText两个文字识别数据集外,还加入了[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg)数据(提取码:627x),和部分真实数据,具体数据细节可以参考论文。
|
||||||
|
|
||||||
|
<a name="2"></a>
|
||||||
|
## 2. 环境配置
|
||||||
|
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
|
||||||
|
|
||||||
|
|
||||||
|
<a name="3"></a>
|
||||||
|
## 3. 模型训练、评估、预测
|
||||||
|
|
||||||
|
请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
|
||||||
|
|
||||||
|
训练
|
||||||
|
|
||||||
|
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
|
||||||
|
|
||||||
|
```
|
||||||
|
#单卡训练(训练周期长,不建议)
|
||||||
|
python3 tools/train.py -c configs/rec/rec_r31_robustscanner.yml
|
||||||
|
|
||||||
|
#多卡训练,通过--gpus参数指定卡号
|
||||||
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r31_robustscanner.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
评估
|
||||||
|
|
||||||
|
```
|
||||||
|
# GPU 评估, Global.pretrained_model 为待测权重
|
||||||
|
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
|
||||||
|
```
|
||||||
|
|
||||||
|
预测:
|
||||||
|
|
||||||
|
```
|
||||||
|
# 预测使用的配置文件必须与训练一致
|
||||||
|
python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="4"></a>
|
||||||
|
## 4. 推理部署
|
||||||
|
|
||||||
|
<a name="4-1"></a>
|
||||||
|
### 4.1 Python推理
|
||||||
|
首先将RobustScanner文本识别训练过程中保存的模型,转换成inference model。可以使用如下命令进行转换:
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
|
||||||
|
```
|
||||||
|
RobustScanner文本识别模型推理,可以执行如下命令:
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r31_robustscanner/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="RobustScanner" --rec_char_dict_path="ppocr/utils/dict90.txt" --use_space_char=False
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="4-2"></a>
|
||||||
|
### 4.2 C++推理
|
||||||
|
|
||||||
|
由于C++预处理后处理还未支持RobustScanner,所以暂未支持
|
||||||
|
|
||||||
|
<a name="4-3"></a>
|
||||||
|
### 4.3 Serving服务化部署
|
||||||
|
|
||||||
|
暂不支持
|
||||||
|
|
||||||
|
<a name="4-4"></a>
|
||||||
|
### 4.4 更多推理部署
|
||||||
|
|
||||||
|
暂不支持
|
||||||
|
|
||||||
|
<a name="5"></a>
|
||||||
|
## 5. FAQ
|
||||||
|
|
||||||
|
|
||||||
|
## 引用
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{2020RobustScanner,
|
||||||
|
title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
|
||||||
|
author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
|
||||||
|
journal={ECCV2020},
|
||||||
|
year={2020},
|
||||||
|
}
|
||||||
|
```
|
@ -70,6 +70,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
|
|||||||
- [x] [ABINet](./algorithm_rec_abinet_en.md)
|
- [x] [ABINet](./algorithm_rec_abinet_en.md)
|
||||||
- [x] [VisionLAN](./algorithm_rec_visionlan_en.md)
|
- [x] [VisionLAN](./algorithm_rec_visionlan_en.md)
|
||||||
- [x] [SPIN](./algorithm_rec_spin_en.md)
|
- [x] [SPIN](./algorithm_rec_spin_en.md)
|
||||||
|
- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
|
||||||
|
|
||||||
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
||||||
|
|
||||||
@ -92,6 +93,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
|||||||
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|
||||||
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|
||||||
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|
||||||
|
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
|
||||||
|
|
||||||
|
|
||||||
<a name="2"></a>
|
<a name="2"></a>
|
||||||
|
114
doc/doc_en/algorithm_rec_robustscanner_en.md
Normal file
114
doc/doc_en/algorithm_rec_robustscanner_en.md
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
# RobustScanner
|
||||||
|
|
||||||
|
- [1. Introduction](#1)
|
||||||
|
- [2. Environment](#2)
|
||||||
|
- [3. Model Training / Evaluation / Prediction](#3)
|
||||||
|
- [3.1 Training](#3-1)
|
||||||
|
- [3.2 Evaluation](#3-2)
|
||||||
|
- [3.3 Prediction](#3-3)
|
||||||
|
- [4. Inference and Deployment](#4)
|
||||||
|
- [4.1 Python Inference](#4-1)
|
||||||
|
- [4.2 C++ Inference](#4-2)
|
||||||
|
- [4.3 Serving](#4-3)
|
||||||
|
- [4.4 More](#4-4)
|
||||||
|
- [5. FAQ](#5)
|
||||||
|
|
||||||
|
<a name="1"></a>
|
||||||
|
## 1. Introduction
|
||||||
|
|
||||||
|
Paper:
|
||||||
|
> [RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition](https://arxiv.org/pdf/2007.07542.pdf)
|
||||||
|
> Xiaoyu Yue, Zhanghui Kuang, Chenhao Lin, Hongbin Sun, Wayne
|
||||||
|
Zhang
|
||||||
|
> ECCV, 2020
|
||||||
|
|
||||||
|
Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows:
|
||||||
|
|
||||||
|
|Model|Backbone|config|Acc|Download link|
|
||||||
|
| --- | --- | --- | --- | --- |
|
||||||
|
|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
|
||||||
|
|
||||||
|
Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper.
|
||||||
|
|
||||||
|
<a name="2"></a>
|
||||||
|
## 2. Environment
|
||||||
|
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
|
||||||
|
|
||||||
|
|
||||||
|
<a name="3"></a>
|
||||||
|
## 3. Model Training / Evaluation / Prediction
|
||||||
|
|
||||||
|
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
|
||||||
|
|
||||||
|
Training:
|
||||||
|
|
||||||
|
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
|
||||||
|
|
||||||
|
```
|
||||||
|
#Single GPU training (long training period, not recommended)
|
||||||
|
python3 tools/train.py -c configs/rec/rec_r31_robustscanner.yml
|
||||||
|
|
||||||
|
#Multi GPU training, specify the gpu number through the --gpus parameter
|
||||||
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r31_robustscanner.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
Evaluation:
|
||||||
|
|
||||||
|
```
|
||||||
|
# GPU evaluation
|
||||||
|
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
|
||||||
|
```
|
||||||
|
|
||||||
|
Prediction:
|
||||||
|
|
||||||
|
```
|
||||||
|
# The configuration file used for prediction must match the training
|
||||||
|
python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="4"></a>
|
||||||
|
## 4. Inference and Deployment
|
||||||
|
|
||||||
|
<a name="4-1"></a>
|
||||||
|
### 4.1 Python Inference
|
||||||
|
First, the model saved during the RobustScanner text recognition training process is converted into an inference model. you can use the following command to convert:
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
|
||||||
|
```
|
||||||
|
|
||||||
|
For RobustScanner text recognition model inference, the following commands can be executed:
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r31_robustscanner/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="RobustScanner" --rec_char_dict_path="ppocr/utils/dict90.txt" --use_space_char=False
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="4-2"></a>
|
||||||
|
### 4.2 C++ Inference
|
||||||
|
|
||||||
|
Not supported
|
||||||
|
|
||||||
|
<a name="4-3"></a>
|
||||||
|
### 4.3 Serving
|
||||||
|
|
||||||
|
Not supported
|
||||||
|
|
||||||
|
<a name="4-4"></a>
|
||||||
|
### 4.4 More
|
||||||
|
|
||||||
|
Not supported
|
||||||
|
|
||||||
|
<a name="5"></a>
|
||||||
|
## 5. FAQ
|
||||||
|
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{2020RobustScanner,
|
||||||
|
title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
|
||||||
|
author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
|
||||||
|
journal={ECCV2020},
|
||||||
|
year={2020},
|
||||||
|
}
|
||||||
|
```
|
@ -26,8 +26,7 @@ from .make_pse_gt import MakePseGt
|
|||||||
|
|
||||||
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||||
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
||||||
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg
|
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg
|
||||||
|
|
||||||
from .ssl_img_aug import SSLRotateResize
|
from .ssl_img_aug import SSLRotateResize
|
||||||
from .randaugment import RandAugment
|
from .randaugment import RandAugment
|
||||||
from .copy_paste import CopyPaste
|
from .copy_paste import CopyPaste
|
||||||
|
@ -414,6 +414,23 @@ class SVTRRecResizeImg(object):
|
|||||||
data['valid_ratio'] = valid_ratio
|
data['valid_ratio'] = valid_ratio
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
class RobustScannerRecResizeImg(object):
|
||||||
|
def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
|
||||||
|
self.image_shape = image_shape
|
||||||
|
self.width_downsample_ratio = width_downsample_ratio
|
||||||
|
self.max_text_length = max_text_length
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
img = data['image']
|
||||||
|
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
|
||||||
|
img, self.image_shape, self.width_downsample_ratio)
|
||||||
|
word_positons = np.array(range(0, self.max_text_length)).astype('int64')
|
||||||
|
data['image'] = norm_img
|
||||||
|
data['resized_shape'] = resize_shape
|
||||||
|
data['pad_shape'] = pad_shape
|
||||||
|
data['valid_ratio'] = valid_ratio
|
||||||
|
data['word_positons'] = word_positons
|
||||||
|
return data
|
||||||
|
|
||||||
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
||||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||||
|
@ -29,27 +29,29 @@ import numpy as np
|
|||||||
|
|
||||||
__all__ = ["ResNet31"]
|
__all__ = ["ResNet31"]
|
||||||
|
|
||||||
|
def conv3x3(in_channel, out_channel, stride=1, conv_weight_attr=None):
|
||||||
def conv3x3(in_channel, out_channel, stride=1):
|
|
||||||
return nn.Conv2D(
|
return nn.Conv2D(
|
||||||
in_channel,
|
in_channel,
|
||||||
out_channel,
|
out_channel,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
padding=1,
|
padding=1,
|
||||||
|
weight_attr=conv_weight_attr,
|
||||||
bias_attr=False)
|
bias_attr=False)
|
||||||
|
|
||||||
|
|
||||||
class BasicBlock(nn.Layer):
|
class BasicBlock(nn.Layer):
|
||||||
expansion = 1
|
expansion = 1
|
||||||
|
|
||||||
def __init__(self, in_channels, channels, stride=1, downsample=False):
|
def __init__(self, in_channels, channels, stride=1, downsample=False, conv_weight_attr=None, bn_weight_attr=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = conv3x3(in_channels, channels, stride)
|
self.conv1 = conv3x3(in_channels, channels, stride,
|
||||||
self.bn1 = nn.BatchNorm2D(channels)
|
conv_weight_attr=conv_weight_attr)
|
||||||
|
self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.conv2 = conv3x3(channels, channels)
|
self.conv2 = conv3x3(channels, channels,
|
||||||
self.bn2 = nn.BatchNorm2D(channels)
|
conv_weight_attr=conv_weight_attr)
|
||||||
|
self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
if downsample:
|
if downsample:
|
||||||
self.downsample = nn.Sequential(
|
self.downsample = nn.Sequential(
|
||||||
@ -58,8 +60,9 @@ class BasicBlock(nn.Layer):
|
|||||||
channels * self.expansion,
|
channels * self.expansion,
|
||||||
1,
|
1,
|
||||||
stride,
|
stride,
|
||||||
|
weight_attr=conv_weight_attr,
|
||||||
bias_attr=False),
|
bias_attr=False),
|
||||||
nn.BatchNorm2D(channels * self.expansion), )
|
nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr))
|
||||||
else:
|
else:
|
||||||
self.downsample = nn.Sequential()
|
self.downsample = nn.Sequential()
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
@ -91,6 +94,7 @@ class ResNet31(nn.Layer):
|
|||||||
channels (list[int]): List of out_channels of Conv2d layer.
|
channels (list[int]): List of out_channels of Conv2d layer.
|
||||||
out_indices (None | Sequence[int]): Indices of output stages.
|
out_indices (None | Sequence[int]): Indices of output stages.
|
||||||
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
|
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
|
||||||
|
init_type (None | str): the config to control the initialization.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -98,7 +102,8 @@ class ResNet31(nn.Layer):
|
|||||||
layers=[1, 2, 5, 3],
|
layers=[1, 2, 5, 3],
|
||||||
channels=[64, 128, 256, 256, 512, 512, 512],
|
channels=[64, 128, 256, 256, 512, 512, 512],
|
||||||
out_indices=None,
|
out_indices=None,
|
||||||
last_stage_pool=False):
|
last_stage_pool=False,
|
||||||
|
init_type=None):
|
||||||
super(ResNet31, self).__init__()
|
super(ResNet31, self).__init__()
|
||||||
assert isinstance(in_channels, int)
|
assert isinstance(in_channels, int)
|
||||||
assert isinstance(last_stage_pool, bool)
|
assert isinstance(last_stage_pool, bool)
|
||||||
@ -106,42 +111,55 @@ class ResNet31(nn.Layer):
|
|||||||
self.out_indices = out_indices
|
self.out_indices = out_indices
|
||||||
self.last_stage_pool = last_stage_pool
|
self.last_stage_pool = last_stage_pool
|
||||||
|
|
||||||
|
conv_weight_attr = None
|
||||||
|
bn_weight_attr = None
|
||||||
|
|
||||||
|
if init_type is not None:
|
||||||
|
support_dict = ['KaimingNormal']
|
||||||
|
assert init_type in support_dict, Exception(
|
||||||
|
"resnet31 only support {}".format(support_dict))
|
||||||
|
conv_weight_attr = nn.initializer.KaimingNormal()
|
||||||
|
bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1)
|
||||||
|
|
||||||
# conv 1 (Conv Conv)
|
# conv 1 (Conv Conv)
|
||||||
self.conv1_1 = nn.Conv2D(
|
self.conv1_1 = nn.Conv2D(
|
||||||
in_channels, channels[0], kernel_size=3, stride=1, padding=1)
|
in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
|
||||||
self.bn1_1 = nn.BatchNorm2D(channels[0])
|
self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr)
|
||||||
self.relu1_1 = nn.ReLU()
|
self.relu1_1 = nn.ReLU()
|
||||||
|
|
||||||
self.conv1_2 = nn.Conv2D(
|
self.conv1_2 = nn.Conv2D(
|
||||||
channels[0], channels[1], kernel_size=3, stride=1, padding=1)
|
channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
|
||||||
self.bn1_2 = nn.BatchNorm2D(channels[1])
|
self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr)
|
||||||
self.relu1_2 = nn.ReLU()
|
self.relu1_2 = nn.ReLU()
|
||||||
|
|
||||||
# conv 2 (Max-pooling, Residual block, Conv)
|
# conv 2 (Max-pooling, Residual block, Conv)
|
||||||
self.pool2 = nn.MaxPool2D(
|
self.pool2 = nn.MaxPool2D(
|
||||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||||
self.block2 = self._make_layer(channels[1], channels[2], layers[0])
|
self.block2 = self._make_layer(channels[1], channels[2], layers[0],
|
||||||
|
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
|
||||||
self.conv2 = nn.Conv2D(
|
self.conv2 = nn.Conv2D(
|
||||||
channels[2], channels[2], kernel_size=3, stride=1, padding=1)
|
channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
|
||||||
self.bn2 = nn.BatchNorm2D(channels[2])
|
self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr)
|
||||||
self.relu2 = nn.ReLU()
|
self.relu2 = nn.ReLU()
|
||||||
|
|
||||||
# conv 3 (Max-pooling, Residual block, Conv)
|
# conv 3 (Max-pooling, Residual block, Conv)
|
||||||
self.pool3 = nn.MaxPool2D(
|
self.pool3 = nn.MaxPool2D(
|
||||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||||
self.block3 = self._make_layer(channels[2], channels[3], layers[1])
|
self.block3 = self._make_layer(channels[2], channels[3], layers[1],
|
||||||
|
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
|
||||||
self.conv3 = nn.Conv2D(
|
self.conv3 = nn.Conv2D(
|
||||||
channels[3], channels[3], kernel_size=3, stride=1, padding=1)
|
channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
|
||||||
self.bn3 = nn.BatchNorm2D(channels[3])
|
self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr)
|
||||||
self.relu3 = nn.ReLU()
|
self.relu3 = nn.ReLU()
|
||||||
|
|
||||||
# conv 4 (Max-pooling, Residual block, Conv)
|
# conv 4 (Max-pooling, Residual block, Conv)
|
||||||
self.pool4 = nn.MaxPool2D(
|
self.pool4 = nn.MaxPool2D(
|
||||||
kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
|
kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
|
||||||
self.block4 = self._make_layer(channels[3], channels[4], layers[2])
|
self.block4 = self._make_layer(channels[3], channels[4], layers[2],
|
||||||
|
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
|
||||||
self.conv4 = nn.Conv2D(
|
self.conv4 = nn.Conv2D(
|
||||||
channels[4], channels[4], kernel_size=3, stride=1, padding=1)
|
channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
|
||||||
self.bn4 = nn.BatchNorm2D(channels[4])
|
self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr)
|
||||||
self.relu4 = nn.ReLU()
|
self.relu4 = nn.ReLU()
|
||||||
|
|
||||||
# conv 5 ((Max-pooling), Residual block, Conv)
|
# conv 5 ((Max-pooling), Residual block, Conv)
|
||||||
@ -149,15 +167,16 @@ class ResNet31(nn.Layer):
|
|||||||
if self.last_stage_pool:
|
if self.last_stage_pool:
|
||||||
self.pool5 = nn.MaxPool2D(
|
self.pool5 = nn.MaxPool2D(
|
||||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||||
self.block5 = self._make_layer(channels[4], channels[5], layers[3])
|
self.block5 = self._make_layer(channels[4], channels[5], layers[3],
|
||||||
|
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
|
||||||
self.conv5 = nn.Conv2D(
|
self.conv5 = nn.Conv2D(
|
||||||
channels[5], channels[5], kernel_size=3, stride=1, padding=1)
|
channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
|
||||||
self.bn5 = nn.BatchNorm2D(channels[5])
|
self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr)
|
||||||
self.relu5 = nn.ReLU()
|
self.relu5 = nn.ReLU()
|
||||||
|
|
||||||
self.out_channels = channels[-1]
|
self.out_channels = channels[-1]
|
||||||
|
|
||||||
def _make_layer(self, input_channels, output_channels, blocks):
|
def _make_layer(self, input_channels, output_channels, blocks, conv_weight_attr=None, bn_weight_attr=None):
|
||||||
layers = []
|
layers = []
|
||||||
for _ in range(blocks):
|
for _ in range(blocks):
|
||||||
downsample = None
|
downsample = None
|
||||||
@ -168,12 +187,14 @@ class ResNet31(nn.Layer):
|
|||||||
output_channels,
|
output_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
|
weight_attr=conv_weight_attr,
|
||||||
bias_attr=False),
|
bias_attr=False),
|
||||||
nn.BatchNorm2D(output_channels), )
|
nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr))
|
||||||
|
|
||||||
layers.append(
|
layers.append(
|
||||||
BasicBlock(
|
BasicBlock(
|
||||||
input_channels, output_channels, downsample=downsample))
|
input_channels, output_channels, downsample=downsample,
|
||||||
|
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr))
|
||||||
input_channels = output_channels
|
input_channels = output_channels
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
@ -35,6 +35,7 @@ def build_head(config):
|
|||||||
from .rec_multi_head import MultiHead
|
from .rec_multi_head import MultiHead
|
||||||
from .rec_spin_att_head import SPINAttentionHead
|
from .rec_spin_att_head import SPINAttentionHead
|
||||||
from .rec_abinet_head import ABINetHead
|
from .rec_abinet_head import ABINetHead
|
||||||
|
from .rec_robustscanner_head import RobustScannerHead
|
||||||
from .rec_visionlan_head import VLHead
|
from .rec_visionlan_head import VLHead
|
||||||
|
|
||||||
# cls head
|
# cls head
|
||||||
@ -51,7 +52,7 @@ def build_head(config):
|
|||||||
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
||||||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
|
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
|
||||||
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
|
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
|
||||||
'VLHead', 'SLAHead'
|
'VLHead', 'SLAHead', 'RobustScannerHead'
|
||||||
]
|
]
|
||||||
|
|
||||||
#table head
|
#table head
|
||||||
|
709
ppocr/modeling/heads/rec_robustscanner_head.py
Normal file
709
ppocr/modeling/heads/rec_robustscanner_head.py
Normal file
@ -0,0 +1,709 @@
|
|||||||
|
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This code is refer from:
|
||||||
|
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/channel_reduction_encoder.py
|
||||||
|
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/robust_scanner_decoder.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
import paddle
|
||||||
|
from paddle import ParamAttr
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
class BaseDecoder(nn.Layer):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward_train(self, feat, out_enc, targets, img_metas):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward_test(self, feat, out_enc, img_metas):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
feat,
|
||||||
|
out_enc,
|
||||||
|
label=None,
|
||||||
|
valid_ratios=None,
|
||||||
|
word_positions=None,
|
||||||
|
train_mode=True):
|
||||||
|
self.train_mode = train_mode
|
||||||
|
|
||||||
|
if train_mode:
|
||||||
|
return self.forward_train(feat, out_enc, label, valid_ratios, word_positions)
|
||||||
|
return self.forward_test(feat, out_enc, valid_ratios, word_positions)
|
||||||
|
|
||||||
|
class ChannelReductionEncoder(nn.Layer):
|
||||||
|
"""Change the channel number with a one by one convoluational layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input channels.
|
||||||
|
out_channels (int): Number of output channels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
**kwargs):
|
||||||
|
super(ChannelReductionEncoder, self).__init__()
|
||||||
|
|
||||||
|
self.layer = nn.Conv2D(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0, weight_attr=nn.initializer.XavierNormal())
|
||||||
|
|
||||||
|
def forward(self, feat):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
feat (Tensor): Image features with the shape of
|
||||||
|
:math:`(N, C_{in}, H, W)`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`.
|
||||||
|
"""
|
||||||
|
return self.layer(feat)
|
||||||
|
|
||||||
|
|
||||||
|
def masked_fill(x, mask, value):
|
||||||
|
y = paddle.full(x.shape, value, x.dtype)
|
||||||
|
return paddle.where(mask, y, x)
|
||||||
|
|
||||||
|
class DotProductAttentionLayer(nn.Layer):
|
||||||
|
|
||||||
|
def __init__(self, dim_model=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.scale = dim_model**-0.5 if dim_model is not None else 1.
|
||||||
|
|
||||||
|
def forward(self, query, key, value, h, w, valid_ratios=None):
|
||||||
|
query = paddle.transpose(query, (0, 2, 1))
|
||||||
|
logits = paddle.matmul(query, key) * self.scale
|
||||||
|
n, c, t = logits.shape
|
||||||
|
# reshape to (n, c, h, w)
|
||||||
|
logits = paddle.reshape(logits, [n, c, h, w])
|
||||||
|
if valid_ratios is not None:
|
||||||
|
# cal mask of attention weight
|
||||||
|
for i, valid_ratio in enumerate(valid_ratios):
|
||||||
|
valid_width = min(w, int(w * valid_ratio + 0.5))
|
||||||
|
if valid_width < w:
|
||||||
|
logits[i, :, :, valid_width:] = float('-inf')
|
||||||
|
|
||||||
|
# reshape to (n, c, h, w)
|
||||||
|
logits = paddle.reshape(logits, [n, c, t])
|
||||||
|
weights = F.softmax(logits, axis=2)
|
||||||
|
value = paddle.transpose(value, (0, 2, 1))
|
||||||
|
glimpse = paddle.matmul(weights, value)
|
||||||
|
glimpse = paddle.transpose(glimpse, (0, 2, 1))
|
||||||
|
return glimpse
|
||||||
|
|
||||||
|
class SequenceAttentionDecoder(BaseDecoder):
|
||||||
|
"""Sequence attention decoder for RobustScanner.
|
||||||
|
|
||||||
|
RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
|
||||||
|
Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_classes (int): Number of output classes :math:`C`.
|
||||||
|
rnn_layers (int): Number of RNN layers.
|
||||||
|
dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
|
||||||
|
dim_model (int): Dimension :math:`D_m` of the model. Should also be the
|
||||||
|
same as encoder output vector ``out_enc``.
|
||||||
|
max_seq_len (int): Maximum output sequence length :math:`T`.
|
||||||
|
start_idx (int): The index of `<SOS>`.
|
||||||
|
mask (bool): Whether to mask input features according to
|
||||||
|
``img_meta['valid_ratio']``.
|
||||||
|
padding_idx (int): The index of `<PAD>`.
|
||||||
|
dropout (float): Dropout rate.
|
||||||
|
return_feature (bool): Return feature or logits as the result.
|
||||||
|
encode_value (bool): Whether to use the output of encoder ``out_enc``
|
||||||
|
as `value` of attention layer. If False, the original feature
|
||||||
|
``feat`` will be used.
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
This decoder will not predict the final class which is assumed to be
|
||||||
|
`<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
|
||||||
|
is also ignored by loss as specified in
|
||||||
|
:obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_classes=None,
|
||||||
|
rnn_layers=2,
|
||||||
|
dim_input=512,
|
||||||
|
dim_model=128,
|
||||||
|
max_seq_len=40,
|
||||||
|
start_idx=0,
|
||||||
|
mask=True,
|
||||||
|
padding_idx=None,
|
||||||
|
dropout=0,
|
||||||
|
return_feature=False,
|
||||||
|
encode_value=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.dim_input = dim_input
|
||||||
|
self.dim_model = dim_model
|
||||||
|
self.return_feature = return_feature
|
||||||
|
self.encode_value = encode_value
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.start_idx = start_idx
|
||||||
|
self.mask = mask
|
||||||
|
|
||||||
|
self.embedding = nn.Embedding(
|
||||||
|
self.num_classes, self.dim_model, padding_idx=padding_idx)
|
||||||
|
|
||||||
|
self.sequence_layer = nn.LSTM(
|
||||||
|
input_size=dim_model,
|
||||||
|
hidden_size=dim_model,
|
||||||
|
num_layers=rnn_layers,
|
||||||
|
time_major=False,
|
||||||
|
dropout=dropout)
|
||||||
|
|
||||||
|
self.attention_layer = DotProductAttentionLayer()
|
||||||
|
|
||||||
|
self.prediction = None
|
||||||
|
if not self.return_feature:
|
||||||
|
pred_num_classes = num_classes - 1
|
||||||
|
self.prediction = nn.Linear(
|
||||||
|
dim_model if encode_value else dim_input, pred_num_classes)
|
||||||
|
|
||||||
|
def forward_train(self, feat, out_enc, targets, valid_ratios):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
|
||||||
|
out_enc (Tensor): Encoder output of shape
|
||||||
|
:math:`(N, D_m, H, W)`.
|
||||||
|
targets (Tensor): a tensor of shape :math:`(N, T)`. Each element is the index of a
|
||||||
|
character.
|
||||||
|
valid_ratios (Tensor): valid length ratio of img.
|
||||||
|
Returns:
|
||||||
|
Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
|
||||||
|
``return_feature=False``. Otherwise it would be the hidden feature
|
||||||
|
before the prediction projection layer, whose shape is
|
||||||
|
:math:`(N, T, D_m)`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tgt_embedding = self.embedding(targets)
|
||||||
|
|
||||||
|
n, c_enc, h, w = out_enc.shape
|
||||||
|
assert c_enc == self.dim_model
|
||||||
|
_, c_feat, _, _ = feat.shape
|
||||||
|
assert c_feat == self.dim_input
|
||||||
|
_, len_q, c_q = tgt_embedding.shape
|
||||||
|
assert c_q == self.dim_model
|
||||||
|
assert len_q <= self.max_seq_len
|
||||||
|
|
||||||
|
query, _ = self.sequence_layer(tgt_embedding)
|
||||||
|
query = paddle.transpose(query, (0, 2, 1))
|
||||||
|
key = paddle.reshape(out_enc, [n, c_enc, h * w])
|
||||||
|
if self.encode_value:
|
||||||
|
value = key
|
||||||
|
else:
|
||||||
|
value = paddle.reshape(feat, [n, c_feat, h * w])
|
||||||
|
|
||||||
|
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
|
||||||
|
attn_out = paddle.transpose(attn_out, (0, 2, 1))
|
||||||
|
|
||||||
|
if self.return_feature:
|
||||||
|
return attn_out
|
||||||
|
|
||||||
|
out = self.prediction(attn_out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward_test(self, feat, out_enc, valid_ratios):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
|
||||||
|
out_enc (Tensor): Encoder output of shape
|
||||||
|
:math:`(N, D_m, H, W)`.
|
||||||
|
valid_ratios (Tensor): valid length ratio of img.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The output logit sequence tensor of shape
|
||||||
|
:math:`(N, T, C-1)`.
|
||||||
|
"""
|
||||||
|
seq_len = self.max_seq_len
|
||||||
|
batch_size = feat.shape[0]
|
||||||
|
|
||||||
|
decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for i in range(seq_len):
|
||||||
|
step_out = self.forward_test_step(feat, out_enc, decode_sequence,
|
||||||
|
i, valid_ratios)
|
||||||
|
outputs.append(step_out)
|
||||||
|
max_idx = paddle.argmax(step_out, axis=1, keepdim=False)
|
||||||
|
if i < seq_len - 1:
|
||||||
|
decode_sequence[:, i + 1] = max_idx
|
||||||
|
|
||||||
|
outputs = paddle.stack(outputs, 1)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
|
||||||
|
valid_ratios):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
|
||||||
|
out_enc (Tensor): Encoder output of shape
|
||||||
|
:math:`(N, D_m, H, W)`.
|
||||||
|
decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that
|
||||||
|
stores history decoding result.
|
||||||
|
current_step (int): Current decoding step.
|
||||||
|
valid_ratios (Tensor): valid length ratio of img
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted
|
||||||
|
tokens at current time step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
embed = self.embedding(decode_sequence)
|
||||||
|
|
||||||
|
n, c_enc, h, w = out_enc.shape
|
||||||
|
assert c_enc == self.dim_model
|
||||||
|
_, c_feat, _, _ = feat.shape
|
||||||
|
assert c_feat == self.dim_input
|
||||||
|
_, _, c_q = embed.shape
|
||||||
|
assert c_q == self.dim_model
|
||||||
|
|
||||||
|
query, _ = self.sequence_layer(embed)
|
||||||
|
query = paddle.transpose(query, (0, 2, 1))
|
||||||
|
key = paddle.reshape(out_enc, [n, c_enc, h * w])
|
||||||
|
if self.encode_value:
|
||||||
|
value = key
|
||||||
|
else:
|
||||||
|
value = paddle.reshape(feat, [n, c_feat, h * w])
|
||||||
|
|
||||||
|
# [n, c, l]
|
||||||
|
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
|
||||||
|
out = attn_out[:, :, current_step]
|
||||||
|
|
||||||
|
if self.return_feature:
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = self.prediction(out)
|
||||||
|
out = F.softmax(out, dim=-1)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PositionAwareLayer(nn.Layer):
|
||||||
|
|
||||||
|
def __init__(self, dim_model, rnn_layers=2):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim_model = dim_model
|
||||||
|
|
||||||
|
self.rnn = nn.LSTM(
|
||||||
|
input_size=dim_model,
|
||||||
|
hidden_size=dim_model,
|
||||||
|
num_layers=rnn_layers,
|
||||||
|
time_major=False)
|
||||||
|
|
||||||
|
self.mixer = nn.Sequential(
|
||||||
|
nn.Conv2D(
|
||||||
|
dim_model, dim_model, kernel_size=3, stride=1, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2D(
|
||||||
|
dim_model, dim_model, kernel_size=3, stride=1, padding=1))
|
||||||
|
|
||||||
|
def forward(self, img_feature):
|
||||||
|
n, c, h, w = img_feature.shape
|
||||||
|
rnn_input = paddle.transpose(img_feature, (0, 2, 3, 1))
|
||||||
|
rnn_input = paddle.reshape(rnn_input, (n * h, w, c))
|
||||||
|
rnn_output, _ = self.rnn(rnn_input)
|
||||||
|
rnn_output = paddle.reshape(rnn_output, (n, h, w, c))
|
||||||
|
rnn_output = paddle.transpose(rnn_output, (0, 3, 1, 2))
|
||||||
|
out = self.mixer(rnn_output)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PositionAttentionDecoder(BaseDecoder):
|
||||||
|
"""Position attention decoder for RobustScanner.
|
||||||
|
|
||||||
|
RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
|
||||||
|
Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_classes (int): Number of output classes :math:`C`.
|
||||||
|
rnn_layers (int): Number of RNN layers.
|
||||||
|
dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
|
||||||
|
dim_model (int): Dimension :math:`D_m` of the model. Should also be the
|
||||||
|
same as encoder output vector ``out_enc``.
|
||||||
|
max_seq_len (int): Maximum output sequence length :math:`T`.
|
||||||
|
mask (bool): Whether to mask input features according to
|
||||||
|
``img_meta['valid_ratio']``.
|
||||||
|
return_feature (bool): Return feature or logits as the result.
|
||||||
|
encode_value (bool): Whether to use the output of encoder ``out_enc``
|
||||||
|
as `value` of attention layer. If False, the original feature
|
||||||
|
``feat`` will be used.
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
This decoder will not predict the final class which is assumed to be
|
||||||
|
`<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
|
||||||
|
is also ignored by loss
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_classes=None,
|
||||||
|
rnn_layers=2,
|
||||||
|
dim_input=512,
|
||||||
|
dim_model=128,
|
||||||
|
max_seq_len=40,
|
||||||
|
mask=True,
|
||||||
|
return_feature=False,
|
||||||
|
encode_value=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.dim_input = dim_input
|
||||||
|
self.dim_model = dim_model
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.return_feature = return_feature
|
||||||
|
self.encode_value = encode_value
|
||||||
|
self.mask = mask
|
||||||
|
|
||||||
|
self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)
|
||||||
|
|
||||||
|
self.position_aware_module = PositionAwareLayer(
|
||||||
|
self.dim_model, rnn_layers)
|
||||||
|
|
||||||
|
self.attention_layer = DotProductAttentionLayer()
|
||||||
|
|
||||||
|
self.prediction = None
|
||||||
|
if not self.return_feature:
|
||||||
|
pred_num_classes = num_classes - 1
|
||||||
|
self.prediction = nn.Linear(
|
||||||
|
dim_model if encode_value else dim_input, pred_num_classes)
|
||||||
|
|
||||||
|
def _get_position_index(self, length, batch_size):
|
||||||
|
position_index_list = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
position_index = paddle.arange(0, end=length, step=1, dtype='int64')
|
||||||
|
position_index_list.append(position_index)
|
||||||
|
batch_position_index = paddle.stack(position_index_list, axis=0)
|
||||||
|
return batch_position_index
|
||||||
|
|
||||||
|
def forward_train(self, feat, out_enc, targets, valid_ratios, position_index):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
|
||||||
|
out_enc (Tensor): Encoder output of shape
|
||||||
|
:math:`(N, D_m, H, W)`.
|
||||||
|
targets (dict): A dict with the key ``padded_targets``, a
|
||||||
|
tensor of shape :math:`(N, T)`. Each element is the index of a
|
||||||
|
character.
|
||||||
|
valid_ratios (Tensor): valid length ratio of img.
|
||||||
|
position_index (Tensor): The position of each word.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
|
||||||
|
``return_feature=False``. Otherwise it will be the hidden feature
|
||||||
|
before the prediction projection layer, whose shape is
|
||||||
|
:math:`(N, T, D_m)`.
|
||||||
|
"""
|
||||||
|
n, c_enc, h, w = out_enc.shape
|
||||||
|
assert c_enc == self.dim_model
|
||||||
|
_, c_feat, _, _ = feat.shape
|
||||||
|
assert c_feat == self.dim_input
|
||||||
|
_, len_q = targets.shape
|
||||||
|
assert len_q <= self.max_seq_len
|
||||||
|
|
||||||
|
position_out_enc = self.position_aware_module(out_enc)
|
||||||
|
|
||||||
|
query = self.embedding(position_index)
|
||||||
|
query = paddle.transpose(query, (0, 2, 1))
|
||||||
|
key = paddle.reshape(position_out_enc, (n, c_enc, h * w))
|
||||||
|
if self.encode_value:
|
||||||
|
value = paddle.reshape(out_enc,(n, c_enc, h * w))
|
||||||
|
else:
|
||||||
|
value = paddle.reshape(feat,(n, c_feat, h * w))
|
||||||
|
|
||||||
|
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
|
||||||
|
attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
|
||||||
|
|
||||||
|
if self.return_feature:
|
||||||
|
return attn_out
|
||||||
|
|
||||||
|
return self.prediction(attn_out)
|
||||||
|
|
||||||
|
def forward_test(self, feat, out_enc, valid_ratios, position_index):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
|
||||||
|
out_enc (Tensor): Encoder output of shape
|
||||||
|
:math:`(N, D_m, H, W)`.
|
||||||
|
valid_ratios (Tensor): valid length ratio of img
|
||||||
|
position_index (Tensor): The position of each word.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
|
||||||
|
``return_feature=False``. Otherwise it would be the hidden feature
|
||||||
|
before the prediction projection layer, whose shape is
|
||||||
|
:math:`(N, T, D_m)`.
|
||||||
|
"""
|
||||||
|
n, c_enc, h, w = out_enc.shape
|
||||||
|
assert c_enc == self.dim_model
|
||||||
|
_, c_feat, _, _ = feat.shape
|
||||||
|
assert c_feat == self.dim_input
|
||||||
|
|
||||||
|
position_out_enc = self.position_aware_module(out_enc)
|
||||||
|
|
||||||
|
query = self.embedding(position_index)
|
||||||
|
query = paddle.transpose(query, (0, 2, 1))
|
||||||
|
key = paddle.reshape(position_out_enc, (n, c_enc, h * w))
|
||||||
|
if self.encode_value:
|
||||||
|
value = paddle.reshape(out_enc,(n, c_enc, h * w))
|
||||||
|
else:
|
||||||
|
value = paddle.reshape(feat,(n, c_feat, h * w))
|
||||||
|
|
||||||
|
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
|
||||||
|
attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
|
||||||
|
|
||||||
|
if self.return_feature:
|
||||||
|
return attn_out
|
||||||
|
|
||||||
|
return self.prediction(attn_out)
|
||||||
|
|
||||||
|
class RobustScannerFusionLayer(nn.Layer):
|
||||||
|
|
||||||
|
def __init__(self, dim_model, dim=-1):
|
||||||
|
super(RobustScannerFusionLayer, self).__init__()
|
||||||
|
|
||||||
|
self.dim_model = dim_model
|
||||||
|
self.dim = dim
|
||||||
|
self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)
|
||||||
|
|
||||||
|
def forward(self, x0, x1):
|
||||||
|
assert x0.shape == x1.shape
|
||||||
|
fusion_input = paddle.concat([x0, x1], self.dim)
|
||||||
|
output = self.linear_layer(fusion_input)
|
||||||
|
output = F.glu(output, self.dim)
|
||||||
|
return output
|
||||||
|
|
||||||
|
class RobustScannerDecoder(BaseDecoder):
|
||||||
|
"""Decoder for RobustScanner.
|
||||||
|
|
||||||
|
RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
|
||||||
|
Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_classes (int): Number of output classes :math:`C`.
|
||||||
|
dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
|
||||||
|
dim_model (int): Dimension :math:`D_m` of the model. Should also be the
|
||||||
|
same as encoder output vector ``out_enc``.
|
||||||
|
max_seq_len (int): Maximum output sequence length :math:`T`.
|
||||||
|
start_idx (int): The index of `<SOS>`.
|
||||||
|
mask (bool): Whether to mask input features according to
|
||||||
|
``img_meta['valid_ratio']``.
|
||||||
|
padding_idx (int): The index of `<PAD>`.
|
||||||
|
encode_value (bool): Whether to use the output of encoder ``out_enc``
|
||||||
|
as `value` of attention layer. If False, the original feature
|
||||||
|
``feat`` will be used.
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
This decoder will not predict the final class which is assumed to be
|
||||||
|
`<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
|
||||||
|
is also ignored by loss as specified in
|
||||||
|
:obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_classes=None,
|
||||||
|
dim_input=512,
|
||||||
|
dim_model=128,
|
||||||
|
hybrid_decoder_rnn_layers=2,
|
||||||
|
hybrid_decoder_dropout=0,
|
||||||
|
position_decoder_rnn_layers=2,
|
||||||
|
max_seq_len=40,
|
||||||
|
start_idx=0,
|
||||||
|
mask=True,
|
||||||
|
padding_idx=None,
|
||||||
|
encode_value=False):
|
||||||
|
super().__init__()
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.dim_input = dim_input
|
||||||
|
self.dim_model = dim_model
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.encode_value = encode_value
|
||||||
|
self.start_idx = start_idx
|
||||||
|
self.padding_idx = padding_idx
|
||||||
|
self.mask = mask
|
||||||
|
|
||||||
|
# init hybrid decoder
|
||||||
|
self.hybrid_decoder = SequenceAttentionDecoder(
|
||||||
|
num_classes=num_classes,
|
||||||
|
rnn_layers=hybrid_decoder_rnn_layers,
|
||||||
|
dim_input=dim_input,
|
||||||
|
dim_model=dim_model,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
start_idx=start_idx,
|
||||||
|
mask=mask,
|
||||||
|
padding_idx=padding_idx,
|
||||||
|
dropout=hybrid_decoder_dropout,
|
||||||
|
encode_value=encode_value,
|
||||||
|
return_feature=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# init position decoder
|
||||||
|
self.position_decoder = PositionAttentionDecoder(
|
||||||
|
num_classes=num_classes,
|
||||||
|
rnn_layers=position_decoder_rnn_layers,
|
||||||
|
dim_input=dim_input,
|
||||||
|
dim_model=dim_model,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
mask=mask,
|
||||||
|
encode_value=encode_value,
|
||||||
|
return_feature=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
self.fusion_module = RobustScannerFusionLayer(
|
||||||
|
self.dim_model if encode_value else dim_input)
|
||||||
|
|
||||||
|
pred_num_classes = num_classes - 1
|
||||||
|
self.prediction = nn.Linear(dim_model if encode_value else dim_input,
|
||||||
|
pred_num_classes)
|
||||||
|
|
||||||
|
def forward_train(self, feat, out_enc, target, valid_ratios, word_positions):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
|
||||||
|
out_enc (Tensor): Encoder output of shape
|
||||||
|
:math:`(N, D_m, H, W)`.
|
||||||
|
target (dict): A dict with the key ``padded_targets``, a
|
||||||
|
tensor of shape :math:`(N, T)`. Each element is the index of a
|
||||||
|
character.
|
||||||
|
valid_ratios (Tensor):
|
||||||
|
word_positions (Tensor): The position of each word.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
|
||||||
|
"""
|
||||||
|
hybrid_glimpse = self.hybrid_decoder.forward_train(
|
||||||
|
feat, out_enc, target, valid_ratios)
|
||||||
|
position_glimpse = self.position_decoder.forward_train(
|
||||||
|
feat, out_enc, target, valid_ratios, word_positions)
|
||||||
|
|
||||||
|
fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)
|
||||||
|
|
||||||
|
out = self.prediction(fusion_out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward_test(self, feat, out_enc, valid_ratios, word_positions):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
|
||||||
|
out_enc (Tensor): Encoder output of shape
|
||||||
|
:math:`(N, D_m, H, W)`.
|
||||||
|
valid_ratios (Tensor):
|
||||||
|
word_positions (Tensor): The position of each word.
|
||||||
|
Returns:
|
||||||
|
Tensor: The output logit sequence tensor of shape
|
||||||
|
:math:`(N, T, C-1)`.
|
||||||
|
"""
|
||||||
|
seq_len = self.max_seq_len
|
||||||
|
batch_size = feat.shape[0]
|
||||||
|
|
||||||
|
decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx)
|
||||||
|
|
||||||
|
position_glimpse = self.position_decoder.forward_test(
|
||||||
|
feat, out_enc, valid_ratios, word_positions)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for i in range(seq_len):
|
||||||
|
hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
|
||||||
|
feat, out_enc, decode_sequence, i, valid_ratios)
|
||||||
|
|
||||||
|
fusion_out = self.fusion_module(hybrid_glimpse_step,
|
||||||
|
position_glimpse[:, i, :])
|
||||||
|
|
||||||
|
char_out = self.prediction(fusion_out)
|
||||||
|
char_out = F.softmax(char_out, -1)
|
||||||
|
outputs.append(char_out)
|
||||||
|
max_idx = paddle.argmax(char_out, axis=1, keepdim=False)
|
||||||
|
if i < seq_len - 1:
|
||||||
|
decode_sequence[:, i + 1] = max_idx
|
||||||
|
|
||||||
|
outputs = paddle.stack(outputs, 1)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
class RobustScannerHead(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
out_channels, # 90 + unknown + start + padding
|
||||||
|
in_channels,
|
||||||
|
enc_outchannles=128,
|
||||||
|
hybrid_dec_rnn_layers=2,
|
||||||
|
hybrid_dec_dropout=0,
|
||||||
|
position_dec_rnn_layers=2,
|
||||||
|
start_idx=0,
|
||||||
|
max_text_length=40,
|
||||||
|
mask=True,
|
||||||
|
padding_idx=None,
|
||||||
|
encode_value=False,
|
||||||
|
**kwargs):
|
||||||
|
super(RobustScannerHead, self).__init__()
|
||||||
|
|
||||||
|
# encoder module
|
||||||
|
self.encoder = ChannelReductionEncoder(
|
||||||
|
in_channels=in_channels, out_channels=enc_outchannles)
|
||||||
|
|
||||||
|
# decoder module
|
||||||
|
self.decoder =RobustScannerDecoder(
|
||||||
|
num_classes=out_channels,
|
||||||
|
dim_input=in_channels,
|
||||||
|
dim_model=enc_outchannles,
|
||||||
|
hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers,
|
||||||
|
hybrid_decoder_dropout=hybrid_dec_dropout,
|
||||||
|
position_decoder_rnn_layers=position_dec_rnn_layers,
|
||||||
|
max_seq_len=max_text_length,
|
||||||
|
start_idx=start_idx,
|
||||||
|
mask=mask,
|
||||||
|
padding_idx=padding_idx,
|
||||||
|
encode_value=encode_value)
|
||||||
|
|
||||||
|
def forward(self, inputs, targets=None):
|
||||||
|
'''
|
||||||
|
targets: [label, valid_ratio, word_positions]
|
||||||
|
'''
|
||||||
|
out_enc = self.encoder(inputs)
|
||||||
|
valid_ratios = None
|
||||||
|
word_positions = targets[-1]
|
||||||
|
|
||||||
|
if len(targets) > 1:
|
||||||
|
valid_ratios = targets[-2]
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
label = targets[0] # label
|
||||||
|
label = paddle.to_tensor(label, dtype='int64')
|
||||||
|
final_out = self.decoder(
|
||||||
|
inputs, out_enc, label, valid_ratios, word_positions)
|
||||||
|
if not self.training:
|
||||||
|
final_out = self.decoder(
|
||||||
|
inputs,
|
||||||
|
out_enc,
|
||||||
|
label=None,
|
||||||
|
valid_ratios=valid_ratios,
|
||||||
|
word_positions=word_positions,
|
||||||
|
train_mode=False)
|
||||||
|
return final_out
|
@ -0,0 +1,111 @@
|
|||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 5
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 20
|
||||||
|
save_model_dir: ./output/rec/rec_r31_robustscanner/
|
||||||
|
save_epoch_step: 1
|
||||||
|
# evaluation is run every 2000 iterations
|
||||||
|
eval_batch_step: [0, 2000]
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: ./inference/rec_inference
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path: ppocr/utils/dict90.txt
|
||||||
|
max_text_length: &max_text_length 40
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
rm_symbol: True
|
||||||
|
save_res_path: ./output/rec/predicts_robustscanner.txt
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
name: Piecewise
|
||||||
|
decay_epochs: [3, 4]
|
||||||
|
values: [0.001, 0.0001, 0.00001]
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: RobustScanner
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNet31
|
||||||
|
init_type: KaimingNormal
|
||||||
|
Head:
|
||||||
|
name: RobustScannerHead
|
||||||
|
enc_outchannles: 128
|
||||||
|
hybrid_dec_rnn_layers: 2
|
||||||
|
hybrid_dec_dropout: 0
|
||||||
|
position_dec_rnn_layers: 2
|
||||||
|
start_idx: 91
|
||||||
|
mask: True
|
||||||
|
padding_idx: 92
|
||||||
|
encode_value: False
|
||||||
|
max_text_length: *max_text_length
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: SARLoss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: SARLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: RecMetric
|
||||||
|
is_filter: True
|
||||||
|
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/ic15_data/
|
||||||
|
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- SARLabelEncode: # Class handling label
|
||||||
|
- RobustScannerRecResizeImg:
|
||||||
|
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
|
||||||
|
width_downsample_ratio: 0.25
|
||||||
|
max_text_length: *max_text_length
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 16
|
||||||
|
drop_last: True
|
||||||
|
num_workers: 0
|
||||||
|
use_shared_memory: False
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/ic15_data
|
||||||
|
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- SARLabelEncode: # Class handling label
|
||||||
|
- RobustScannerRecResizeImg:
|
||||||
|
image_shape: [3, 48, 48, 160]
|
||||||
|
max_text_length: *max_text_length
|
||||||
|
width_downsample_ratio: 0.25
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 16
|
||||||
|
num_workers: 0
|
||||||
|
use_shared_memory: False
|
||||||
|
|
@ -0,0 +1,54 @@
|
|||||||
|
===========================train_params===========================
|
||||||
|
model_name:rec_r31_robustscanner
|
||||||
|
python:python
|
||||||
|
gpu_list:0|0,1
|
||||||
|
Global.use_gpu:True|True
|
||||||
|
Global.auto_cast:null
|
||||||
|
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=5
|
||||||
|
Global.save_model_dir:./output/
|
||||||
|
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
|
||||||
|
Global.pretrained_model:null
|
||||||
|
train_model_name:latest
|
||||||
|
train_infer_img_dir:./inference/rec_inference
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
trainer:norm_train
|
||||||
|
norm_train:tools/train.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
|
||||||
|
pact_train:null
|
||||||
|
fpgm_train:null
|
||||||
|
distill_train:null
|
||||||
|
null:null
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
===========================eval_params===========================
|
||||||
|
eval:tools/eval.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
===========================infer_params===========================
|
||||||
|
Global.save_inference_dir:./output/
|
||||||
|
Global.checkpoints:
|
||||||
|
norm_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
|
||||||
|
quant_export:null
|
||||||
|
fpgm_export:null
|
||||||
|
distill_export:null
|
||||||
|
export1:null
|
||||||
|
export2:null
|
||||||
|
##
|
||||||
|
train_model:./inference/rec_r31_robustscanner/best_accuracy
|
||||||
|
infer_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
|
||||||
|
infer_quant:False
|
||||||
|
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.txt --rec_image_shape="3,48,48,160" --use_space_char=False --rec_algorithm="RobustScanner"
|
||||||
|
--use_gpu:True|False
|
||||||
|
--enable_mkldnn:True|False
|
||||||
|
--cpu_threads:1|6
|
||||||
|
--rec_batch_num:1|6
|
||||||
|
--use_tensorrt:False|False
|
||||||
|
--precision:fp32|int8
|
||||||
|
--rec_model_dir:
|
||||||
|
--image_dir:./inference/rec_inference
|
||||||
|
--save_log_path:./test/output/
|
||||||
|
--benchmark:True
|
||||||
|
null:null
|
||||||
|
===========================infer_benchmark_params==========================
|
||||||
|
random_infer_input:[{float32,[3,48,160]}]
|
||||||
|
|
@ -54,6 +54,7 @@
|
|||||||
| NRTR |rec_mtb_nrtr | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
| NRTR |rec_mtb_nrtr | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||||
| SAR |rec_r31_sar | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
| SAR |rec_r31_sar | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||||
| SPIN |rec_r32_gaspin_bilstm_att | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
| SPIN |rec_r32_gaspin_bilstm_att | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||||
|
| RobustScanner |rec_r31_robustscanner | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||||
| PGNet |rec_r34_vd_none_none_ctc_v2.0 | 端到端| 支持 | 多机多卡 <br> 混合精度 | - | - |
|
| PGNet |rec_r34_vd_none_none_ctc_v2.0 | 端到端| 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||||
| TableMaster |table_structure_tablemaster_train | 表格识别| 支持 | 多机多卡 <br> 混合精度 | - | - |
|
| TableMaster |table_structure_tablemaster_train | 表格识别| 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ def main():
|
|||||||
config['Architecture']["Head"]['out_channels'] = char_num
|
config['Architecture']["Head"]['out_channels'] = char_num
|
||||||
|
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
|
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"]
|
||||||
extra_input = False
|
extra_input = False
|
||||||
if config['Architecture']['algorithm'] == 'Distillation':
|
if config['Architecture']['algorithm'] == 'Distillation':
|
||||||
for key in config['Architecture']["Models"]:
|
for key in config['Architecture']["Models"]:
|
||||||
|
@ -111,6 +111,22 @@ def export_single_model(model,
|
|||||||
shape=[None, 3, 64, 256], dtype="float32"),
|
shape=[None, 3, 64, 256], dtype="float32"),
|
||||||
]
|
]
|
||||||
model = to_static(model, input_spec=other_shape)
|
model = to_static(model, input_spec=other_shape)
|
||||||
|
elif arch_config["algorithm"] == "RobustScanner":
|
||||||
|
max_text_length = arch_config["Head"]["max_text_length"]
|
||||||
|
other_shape = [
|
||||||
|
paddle.static.InputSpec(
|
||||||
|
shape=[None, 3, 48, 160], dtype="float32"),
|
||||||
|
|
||||||
|
[
|
||||||
|
paddle.static.InputSpec(
|
||||||
|
shape=[None, ],
|
||||||
|
dtype="float32"),
|
||||||
|
paddle.static.InputSpec(
|
||||||
|
shape=[None, max_text_length],
|
||||||
|
dtype="int64")
|
||||||
|
]
|
||||||
|
]
|
||||||
|
model = to_static(model, input_spec=other_shape)
|
||||||
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
|
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
|
||||||
input_spec = [
|
input_spec = [
|
||||||
paddle.static.InputSpec(
|
paddle.static.InputSpec(
|
||||||
|
@ -68,7 +68,7 @@ class TextRecognizer(object):
|
|||||||
'name': 'SARLabelDecode',
|
'name': 'SARLabelDecode',
|
||||||
"character_dict_path": args.rec_char_dict_path,
|
"character_dict_path": args.rec_char_dict_path,
|
||||||
"use_space_char": args.use_space_char
|
"use_space_char": args.use_space_char
|
||||||
}
|
}
|
||||||
elif self.rec_algorithm == "VisionLAN":
|
elif self.rec_algorithm == "VisionLAN":
|
||||||
postprocess_params = {
|
postprocess_params = {
|
||||||
'name': 'VLLabelDecode',
|
'name': 'VLLabelDecode',
|
||||||
@ -93,6 +93,13 @@ class TextRecognizer(object):
|
|||||||
"character_dict_path": args.rec_char_dict_path,
|
"character_dict_path": args.rec_char_dict_path,
|
||||||
"use_space_char": args.use_space_char
|
"use_space_char": args.use_space_char
|
||||||
}
|
}
|
||||||
|
elif self.rec_algorithm == "RobustScanner":
|
||||||
|
postprocess_params = {
|
||||||
|
'name': 'SARLabelDecode',
|
||||||
|
"character_dict_path": args.rec_char_dict_path,
|
||||||
|
"use_space_char": args.use_space_char,
|
||||||
|
"rm_symbol": True
|
||||||
|
}
|
||||||
self.postprocess_op = build_post_process(postprocess_params)
|
self.postprocess_op = build_post_process(postprocess_params)
|
||||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||||
utility.create_predictor(args, 'rec', logger)
|
utility.create_predictor(args, 'rec', logger)
|
||||||
@ -390,6 +397,18 @@ class TextRecognizer(object):
|
|||||||
img_list[indices[ino]], self.rec_image_shape)
|
img_list[indices[ino]], self.rec_image_shape)
|
||||||
norm_img = norm_img[np.newaxis, :]
|
norm_img = norm_img[np.newaxis, :]
|
||||||
norm_img_batch.append(norm_img)
|
norm_img_batch.append(norm_img)
|
||||||
|
elif self.rec_algorithm == "RobustScanner":
|
||||||
|
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
|
||||||
|
img_list[indices[ino]], self.rec_image_shape, width_downsample_ratio=0.25)
|
||||||
|
norm_img = norm_img[np.newaxis, :]
|
||||||
|
valid_ratio = np.expand_dims(valid_ratio, axis=0)
|
||||||
|
valid_ratios = []
|
||||||
|
valid_ratios.append(valid_ratio)
|
||||||
|
norm_img_batch.append(norm_img)
|
||||||
|
word_positions_list = []
|
||||||
|
word_positions = np.array(range(0, 40)).astype('int64')
|
||||||
|
word_positions = np.expand_dims(word_positions, axis=0)
|
||||||
|
word_positions_list.append(word_positions)
|
||||||
else:
|
else:
|
||||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||||
max_wh_ratio)
|
max_wh_ratio)
|
||||||
@ -442,6 +461,35 @@ class TextRecognizer(object):
|
|||||||
np.array(
|
np.array(
|
||||||
[valid_ratios], dtype=np.float32),
|
[valid_ratios], dtype=np.float32),
|
||||||
]
|
]
|
||||||
|
if self.use_onnx:
|
||||||
|
input_dict = {}
|
||||||
|
input_dict[self.input_tensor.name] = norm_img_batch
|
||||||
|
outputs = self.predictor.run(self.output_tensors,
|
||||||
|
input_dict)
|
||||||
|
preds = outputs[0]
|
||||||
|
else:
|
||||||
|
input_names = self.predictor.get_input_names()
|
||||||
|
for i in range(len(input_names)):
|
||||||
|
input_tensor = self.predictor.get_input_handle(
|
||||||
|
input_names[i])
|
||||||
|
input_tensor.copy_from_cpu(inputs[i])
|
||||||
|
self.predictor.run()
|
||||||
|
outputs = []
|
||||||
|
for output_tensor in self.output_tensors:
|
||||||
|
output = output_tensor.copy_to_cpu()
|
||||||
|
outputs.append(output)
|
||||||
|
if self.benchmark:
|
||||||
|
self.autolog.times.stamp()
|
||||||
|
preds = outputs[0]
|
||||||
|
elif self.rec_algorithm == "RobustScanner":
|
||||||
|
valid_ratios = np.concatenate(valid_ratios)
|
||||||
|
word_positions_list = np.concatenate(word_positions_list)
|
||||||
|
inputs = [
|
||||||
|
norm_img_batch,
|
||||||
|
valid_ratios,
|
||||||
|
word_positions_list
|
||||||
|
]
|
||||||
|
|
||||||
if self.use_onnx:
|
if self.use_onnx:
|
||||||
input_dict = {}
|
input_dict = {}
|
||||||
input_dict[self.input_tensor.name] = norm_img_batch
|
input_dict[self.input_tensor.name] = norm_img_batch
|
||||||
|
@ -96,6 +96,8 @@ def main():
|
|||||||
]
|
]
|
||||||
elif config['Architecture']['algorithm'] == "SAR":
|
elif config['Architecture']['algorithm'] == "SAR":
|
||||||
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
|
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
|
||||||
|
elif config['Architecture']['algorithm'] == "RobustScanner":
|
||||||
|
op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons']
|
||||||
else:
|
else:
|
||||||
op[op_name]['keep_keys'] = ['image']
|
op[op_name]['keep_keys'] = ['image']
|
||||||
transforms.append(op)
|
transforms.append(op)
|
||||||
@ -131,12 +133,20 @@ def main():
|
|||||||
if config['Architecture']['algorithm'] == "SAR":
|
if config['Architecture']['algorithm'] == "SAR":
|
||||||
valid_ratio = np.expand_dims(batch[-1], axis=0)
|
valid_ratio = np.expand_dims(batch[-1], axis=0)
|
||||||
img_metas = [paddle.to_tensor(valid_ratio)]
|
img_metas = [paddle.to_tensor(valid_ratio)]
|
||||||
|
if config['Architecture']['algorithm'] == "RobustScanner":
|
||||||
|
valid_ratio = np.expand_dims(batch[1], axis=0)
|
||||||
|
word_positons = np.expand_dims(batch[2], axis=0)
|
||||||
|
img_metas = [paddle.to_tensor(valid_ratio),
|
||||||
|
paddle.to_tensor(word_positons),
|
||||||
|
]
|
||||||
images = np.expand_dims(batch[0], axis=0)
|
images = np.expand_dims(batch[0], axis=0)
|
||||||
images = paddle.to_tensor(images)
|
images = paddle.to_tensor(images)
|
||||||
if config['Architecture']['algorithm'] == "SRN":
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
preds = model(images, others)
|
preds = model(images, others)
|
||||||
elif config['Architecture']['algorithm'] == "SAR":
|
elif config['Architecture']['algorithm'] == "SAR":
|
||||||
preds = model(images, img_metas)
|
preds = model(images, img_metas)
|
||||||
|
elif config['Architecture']['algorithm'] == "RobustScanner":
|
||||||
|
preds = model(images, img_metas)
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
post_result = post_process_class(preds)
|
post_result = post_process_class(preds)
|
||||||
|
@ -230,7 +230,7 @@ def train(config,
|
|||||||
|
|
||||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||||
extra_input_models = [
|
extra_input_models = [
|
||||||
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN"
|
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", "RobustScanner"
|
||||||
]
|
]
|
||||||
extra_input = False
|
extra_input = False
|
||||||
if config['Architecture']['algorithm'] == 'Distillation':
|
if config['Architecture']['algorithm'] == 'Distillation':
|
||||||
@ -653,7 +653,7 @@ def preprocess(is_train=False):
|
|||||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
|
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
|
||||||
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
|
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
|
||||||
'Gestalt', 'SLANet'
|
'Gestalt', 'SLANet', 'RobustScanner'
|
||||||
]
|
]
|
||||||
|
|
||||||
if use_xpu:
|
if use_xpu:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user