mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-11-03 03:09:16 +00:00
Latexocr paddle (#13401)
* commit_test * modified: configs/rec/rec_latex_ocr.yml deleted: ppocr/modeling/backbones/rec_resnetv2.py * ntuple_solve * style * style * style * style * style * style * style * style * style * delete comment * cla_email
This commit is contained in:
parent
c556b9083e
commit
cf26f2330e
126
configs/rec/rec_latex_ocr.yml
Normal file
126
configs/rec/rec_latex_ocr.yml
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
Global:
|
||||||
|
use_gpu: True
|
||||||
|
epoch_num: 500
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 100
|
||||||
|
save_model_dir: ./output/rec/latex_ocr/
|
||||||
|
save_epoch_step: 5
|
||||||
|
max_seq_len: 512
|
||||||
|
# evaluation is run every 60000 iterations (22 epochs at batch_size = 56)
|
||||||
|
eval_batch_step: [0, 60000]
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/datasets/pme_demo/0000013.png
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
rec_char_dict_path: ppocr/utils/dict/latex_ocr_tokenizer.json
|
||||||
|
save_res_path: ./output/rec/predicts_latexocr.txt
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: AdamW
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
name: Const
|
||||||
|
learning_rate: 0.0001
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: LaTeXOCR
|
||||||
|
in_channels: 1
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: HybridTransformer
|
||||||
|
img_size: [192, 672]
|
||||||
|
patch_size: 16
|
||||||
|
num_classes: 0
|
||||||
|
embed_dim: 256
|
||||||
|
depth: 4
|
||||||
|
num_heads: 8
|
||||||
|
input_channel: 1
|
||||||
|
is_predict: False
|
||||||
|
is_export: False
|
||||||
|
Head:
|
||||||
|
name: LaTeXOCRHead
|
||||||
|
pad_value: 0
|
||||||
|
is_export: False
|
||||||
|
decoder_args:
|
||||||
|
attn_on_attn: True
|
||||||
|
cross_attend: True
|
||||||
|
ff_glu: True
|
||||||
|
rel_pos_bias: False
|
||||||
|
use_scalenorm: False
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: LaTeXOCRLoss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: LaTeXOCRDecode
|
||||||
|
rec_char_dict_path: ppocr/utils/dict/latex_ocr_tokenizer.json
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: LaTeXOCRMetric
|
||||||
|
main_indicator: exp_rate
|
||||||
|
cal_blue_score: False
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: LaTeXOCRDataSet
|
||||||
|
data: ./train_data/LaTeXOCR/latexocr_train.pkl
|
||||||
|
min_dimensions: [32, 32]
|
||||||
|
max_dimensions: [672, 192]
|
||||||
|
batch_size_per_pair: 56
|
||||||
|
keep_smaller_batches: False
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
channel_first: False
|
||||||
|
- MinMaxResize:
|
||||||
|
min_dimensions: [32, 32]
|
||||||
|
max_dimensions: [672, 192]
|
||||||
|
- LatexTrainTransform:
|
||||||
|
bitmap_prob: .04
|
||||||
|
- NormalizeImage:
|
||||||
|
mean: [0.7931, 0.7931, 0.7931]
|
||||||
|
std: [0.1738, 0.1738, 0.1738]
|
||||||
|
order: 'hwc'
|
||||||
|
- LatexImageFormat:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image']
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 1
|
||||||
|
drop_last: False
|
||||||
|
num_workers: 0
|
||||||
|
collate_fn: LaTeXOCRCollator
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: LaTeXOCRDataSet
|
||||||
|
data: ./train_data/LaTeXOCR/latexocr_val.pkl
|
||||||
|
min_dimensions: [32, 32]
|
||||||
|
max_dimensions: [672, 192]
|
||||||
|
batch_size_per_pair: 10
|
||||||
|
keep_smaller_batches: True
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
channel_first: False
|
||||||
|
- MinMaxResize:
|
||||||
|
min_dimensions: [32, 32]
|
||||||
|
max_dimensions: [672, 192]
|
||||||
|
- LatexTestTransform:
|
||||||
|
- NormalizeImage:
|
||||||
|
mean: [0.7931, 0.7931, 0.7931]
|
||||||
|
std: [0.1738, 0.1738, 0.1738]
|
||||||
|
order: 'hwc'
|
||||||
|
- LatexImageFormat:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 1
|
||||||
|
num_workers: 0
|
||||||
|
collate_fn: LaTeXOCRCollator
|
||||||
BIN
doc/datasets/pme_demo/0000013.png
Normal file
BIN
doc/datasets/pme_demo/0000013.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.5 KiB |
BIN
doc/datasets/pme_demo/0000295.png
Normal file
BIN
doc/datasets/pme_demo/0000295.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.3 KiB |
BIN
doc/datasets/pme_demo/0000562.png
Normal file
BIN
doc/datasets/pme_demo/0000562.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.2 KiB |
@ -137,6 +137,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
|
|||||||
|
|
||||||
已支持的公式识别算法列表(戳链接获取使用教程):
|
已支持的公式识别算法列表(戳链接获取使用教程):
|
||||||
- [x] [CAN](./algorithm_rec_can.md)
|
- [x] [CAN](./algorithm_rec_can.md)
|
||||||
|
- [x] [LaTeX-OCR](./algorithm_rec_latex_ocr.md)
|
||||||
|
|
||||||
在CROHME手写公式数据集上,算法效果如下:
|
在CROHME手写公式数据集上,算法效果如下:
|
||||||
|
|
||||||
@ -144,6 +145,13 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
|
|||||||
| ----- | ----- | ----- | ----- | ----- |
|
| ----- | ----- | ----- | ----- | ----- |
|
||||||
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_d28_can_train.tar)|
|
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_d28_can_train.tar)|
|
||||||
|
|
||||||
|
在LaTeX-OCR印刷公式数据集上,算法效果如下:
|
||||||
|
|
||||||
|
| 模型 | 骨干网络 |配置文件 | BLEU score | normed edit distance | ExpRate |下载链接|
|
||||||
|
|-----------|------------| ----- |:-----------:|:---------------------:|:---------:| ----- |
|
||||||
|
| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)|
|
||||||
|
|
||||||
|
|
||||||
<a name="2"></a>
|
<a name="2"></a>
|
||||||
|
|
||||||
## 2. 端到端算法
|
## 2. 端到端算法
|
||||||
|
|||||||
171
doc/doc_ch/algorithm_rec_latex_ocr.md
Normal file
171
doc/doc_ch/algorithm_rec_latex_ocr.md
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
# 印刷数学公式识别算法-LaTeX-OCR
|
||||||
|
|
||||||
|
- [1. 算法简介](#1)
|
||||||
|
- [2. 环境配置](#2)
|
||||||
|
- [3. 模型训练、评估、预测](#3)
|
||||||
|
- [3.1 pickle 标签文件生成](#3-1)
|
||||||
|
- [3.2 训练](#3-2)
|
||||||
|
- [3.3 评估](#3-3)
|
||||||
|
- [3.4 预测](#3-4)
|
||||||
|
- [4. 推理部署](#4)
|
||||||
|
- [4.1 Python推理](#4-1)
|
||||||
|
- [4.2 C++推理](#4-2)
|
||||||
|
- [4.3 Serving服务化部署](#4-3)
|
||||||
|
- [4.4 更多推理部署](#4-4)
|
||||||
|
- [5. FAQ](#5)
|
||||||
|
|
||||||
|
<a name="1"></a>
|
||||||
|
## 1. 算法简介
|
||||||
|
|
||||||
|
原始项目:
|
||||||
|
> [https://github.com/lukas-blecher/LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
<a name="model"></a>
|
||||||
|
`LaTeX-OCR`使用[`LaTeX-OCR印刷公式数据集`](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO)进行训练,在对应测试集上的精度如下:
|
||||||
|
|
||||||
|
| 模型 | 骨干网络 |配置文件 | BLEU score | normed edit distance | ExpRate |下载链接|
|
||||||
|
|-----------|------------| ----- |:-----------:|:---------------------:|:---------:| ----- |
|
||||||
|
| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)|
|
||||||
|
|
||||||
|
<a name="2"></a>
|
||||||
|
## 2. 环境配置
|
||||||
|
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
|
||||||
|
|
||||||
|
<a name="3"></a>
|
||||||
|
## 3. 模型训练、评估、预测
|
||||||
|
|
||||||
|
<a name="3-1"></a>
|
||||||
|
|
||||||
|
### 3.1 pickle 标签文件生成
|
||||||
|
从[谷歌云盘](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO)中下载 formulae.zip 和 math.txt,之后,使用如下命令,生成 pickle 标签文件。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# 创建 LaTeX-OCR 数据集目录
|
||||||
|
mkdir -p train_data/LaTeXOCR
|
||||||
|
# 解压formulae.zip ,并拷贝math.txt
|
||||||
|
unzip -d train_data/LaTeXOCR path/formulae.zip
|
||||||
|
cp path/math.txt train_data/LaTeXOCR
|
||||||
|
# 将原始的 .txt 文件转换为 .pkl 文件,从而对不同尺度的图像进行分组
|
||||||
|
# 训练集转换
|
||||||
|
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/train --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
|
||||||
|
# 验证集转换
|
||||||
|
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/val --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
|
||||||
|
# 测试集转换
|
||||||
|
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/test --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="3-2"></a>

### 3.2 模型训练
|
||||||
|
|
||||||
|
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`LaTeX-OCR`识别模型时需要**更换配置文件**为`LaTeX-OCR`的[配置文件](../../configs/rec/rec_latex_ocr.yml)。
|
||||||
|
|
||||||
|
#### 启动训练
|
||||||
|
|
||||||
|
|
||||||
|
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
|
||||||
|
```shell
|
||||||
|
#单卡训练 (默认训练方式)
|
||||||
|
python3 tools/train.py -c configs/rec/rec_latex_ocr.yml
|
||||||
|
#多卡训练,通过--gpus参数指定卡号
|
||||||
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_latex_ocr.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
**注意:**
|
||||||
|
|
||||||
|
- 默认每训练22个epoch(60000次iteration)进行1次评估,若您更改训练的batch_size,或更换数据集,请在训练时作出如下修改
|
||||||
|
```
|
||||||
|
python3 tools/train.py -c configs/rec/rec_latex_ocr.yml -o Global.eval_batch_step=[0,{length_of_dataset//batch_size*22}]
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="3-3"></a>
|
||||||
|
### 3.3 评估
|
||||||
|
|
||||||
|
可下载已训练完成的[模型文件](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar),使用如下命令进行评估:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# 注意将pretrained_model的路径设置为本地路径。若使用自行训练保存的模型,请注意修改路径和文件名为{path/to/weights}/{model_name}。
|
||||||
|
# 验证集评估
|
||||||
|
python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True
|
||||||
|
# 测试集评估
|
||||||
|
python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True Eval.dataset.data=./train_data/LaTeXOCR/latexocr_test.pkl
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="3-4"></a>
|
||||||
|
### 3.4 预测
|
||||||
|
|
||||||
|
使用如下命令进行单张图片预测:
|
||||||
|
```shell
|
||||||
|
# 注意将pretrained_model的路径设置为本地路径。
|
||||||
|
python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True Global.infer_img='./doc/datasets/pme_demo/0000013.png' Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams
|
||||||
|
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/datasets/pme_demo/'。
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="4"></a>
|
||||||
|
## 4. 推理部署
|
||||||
|
|
||||||
|
<a name="4-1"></a>
|
||||||
|
### 4.1 Python推理
|
||||||
|
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar) ),可以使用如下命令进行转换:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# 注意将pretrained_model的路径设置为本地路径。
|
||||||
|
python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Global.save_inference_dir=./inference/rec_latex_ocr_infer/ Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True
|
||||||
|
|
||||||
|
# 目前的静态图模型支持的最大输出长度为512
|
||||||
|
```
|
||||||
|
**注意:**
|
||||||
|
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请检查配置文件中的`rec_char_dict_path`是否为所需要的字典文件。
|
||||||
|
- [转换后模型下载地址](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_infer.tar)
|
||||||
|
|
||||||
|
转换成功后,在目录下有三个文件:
|
||||||
|
```
|
||||||
|
/inference/rec_latex_ocr_infer/
|
||||||
|
├── inference.pdiparams # 识别inference模型的参数文件
|
||||||
|
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
|
||||||
|
└── inference.pdmodel # 识别inference模型的program文件
|
||||||
|
```
|
||||||
|
|
||||||
|
执行如下命令进行模型推理:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python3 tools/infer/predict_rec.py --image_dir='./doc/datasets/pme_demo/0000295.png' --rec_algorithm="LaTeXOCR" --rec_batch_num=1 --rec_model_dir="./inference/rec_latex_ocr_infer/" --rec_char_dict_path="./ppocr/utils/dict/latex_ocr_tokenizer.json"
|
||||||
|
|
||||||
|
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/datasets/pme_demo/'。
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
执行命令后,上面图像的预测结果(识别的文本)会打印到屏幕上,示例如下:
|
||||||
|
```shell
|
||||||
|
Predicts of ./doc/datasets/pme_demo/0000295.png:\zeta_{0}(\nu)=-{\frac{\nu\varrho^{-2\nu}}{\pi}}\int_{\mu}^{\infty}d\omega\int_{C_{+}}d z{\frac{2z^{2}}{(z^{2}+\omega^{2})^{\nu+1}}}{\tilde{\Psi}}(\omega;z)e^{i\epsilon z}~~~,
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
**注意**:
|
||||||
|
|
||||||
|
- 需要注意预测图像为**白底黑字**,即公式部分为黑色,背景为白色的图片。
|
||||||
|
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
|
||||||
|
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中 LaTeX-OCR 的预处理为您的预处理方法。
|
||||||
|
|
||||||
|
|
||||||
|
<a name="4-2"></a>
|
||||||
|
### 4.2 C++推理部署
|
||||||
|
|
||||||
|
由于C++预处理后处理还未支持 LaTeX-OCR,所以暂未支持
|
||||||
|
|
||||||
|
<a name="4-3"></a>
|
||||||
|
### 4.3 Serving服务化部署
|
||||||
|
|
||||||
|
暂不支持
|
||||||
|
|
||||||
|
<a name="4-4"></a>
|
||||||
|
### 4.4 更多推理部署
|
||||||
|
|
||||||
|
暂不支持
|
||||||
|
|
||||||
|
<a name="5"></a>
|
||||||
|
## 5. FAQ
|
||||||
|
|
||||||
|
1. LaTeX-OCR 数据集来自于[LaTeXOCR源repo](https://github.com/lukas-blecher/LaTeX-OCR) 。
|
||||||
@ -137,6 +137,8 @@ On the TextZoom public dataset, the effect of the algorithm is as follows:
|
|||||||
Supported formula recognition algorithms (Click the link to get the tutorial):
|
Supported formula recognition algorithms (Click the link to get the tutorial):
|
||||||
|
|
||||||
- [x] [CAN](./algorithm_rec_can_en.md)
|
- [x] [CAN](./algorithm_rec_can_en.md)
|
||||||
|
- [x] [LaTeX-OCR](./algorithm_rec_latex_ocr_en.md)
|
||||||
|
|
||||||
|
|
||||||
On the CROHME handwritten formula dataset, the effect of the algorithm is as follows:
|
On the CROHME handwritten formula dataset, the effect of the algorithm is as follows:
|
||||||
|
|
||||||
@ -145,6 +147,13 @@ On the CROHME handwritten formula dataset, the effect of the algorithm is as fol
|
|||||||
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_d28_can_train.tar)|
|
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_d28_can_train.tar)|
|
||||||
|
|
||||||
|
|
||||||
|
On the LaTeX-OCR printed formula dataset, the effect of the algorithm is as follows:
|
||||||
|
|
||||||
|
| Model | Backbone |config| BLEU score | normed edit distance | ExpRate |Download link|
|
||||||
|
|-----------|----------| ---- |:-----------:|:---------------------:|:---------:| ----- |
|
||||||
|
| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)|
|
||||||
|
|
||||||
|
|
||||||
<a name="2"></a>
|
<a name="2"></a>
|
||||||
|
|
||||||
## 2. End-to-end OCR Algorithms
|
## 2. End-to-end OCR Algorithms
|
||||||
|
|||||||
127
doc/doc_en/algorithm_rec_latex_ocr_en.md
Normal file
127
doc/doc_en/algorithm_rec_latex_ocr_en.md
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
# LaTeX-OCR
|
||||||
|
|
||||||
|
- [1. Introduction](#1)
|
||||||
|
- [2. Environment](#2)
|
||||||
|
- [3. Model Training / Evaluation / Prediction](#3)
|
||||||
|
- [3.1 Pickle File Generation](#3-1)
|
||||||
|
- [3.2 Training](#3-2)
|
||||||
|
- [3.3 Evaluation](#3-3)
|
||||||
|
- [3.4 Prediction](#3-4)
|
||||||
|
- [4. Inference and Deployment](#4)
|
||||||
|
- [4.1 Python Inference](#4-1)
|
||||||
|
- [4.2 C++ Inference](#4-2)
|
||||||
|
- [4.3 Serving](#4-3)
|
||||||
|
- [4.4 More](#4-4)
|
||||||
|
- [5. FAQ](#5)
|
||||||
|
|
||||||
|
<a name="1"></a>
|
||||||
|
## 1. Introduction
|
||||||
|
|
||||||
|
Original Project:
|
||||||
|
> [https://github.com/lukas-blecher/LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)
|
||||||
|
|
||||||
|
|
||||||
|
Trained on the LaTeX-OCR printed mathematical expression recognition dataset and evaluated on the corresponding test set, the reproduced model achieves the following results:
|
||||||
|
|
||||||
|
| Model | Backbone |config| BLEU score | normed edit distance | ExpRate |Download link|
|
||||||
|
|-----------|----------| ---- |:-----------:|:---------------------:|:---------:| ----- |
|
||||||
|
| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)|
|
||||||
|
|
||||||
|
<a name="2"></a>
|
||||||
|
## 2. Environment
|
||||||
|
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
|
||||||
|
|
||||||
|
|
||||||
|
<a name="3"></a>
|
||||||
|
## 3. Model Training / Evaluation / Prediction
|
||||||
|
|
||||||
|
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
|
||||||
|
|
||||||
|
<a name="3-1"></a>

### 3.1 Pickle File Generation
|
||||||
|
|
||||||
|
Download formulae.zip and math.txt in [Google Drive](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO), and then use the following command to generate the pickle file.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# Create a LaTeX-OCR dataset directory
|
||||||
|
mkdir -p train_data/LaTeXOCR
|
||||||
|
# Unzip formulae.zip and copy math.txt
|
||||||
|
unzip -d train_data/LaTeXOCR path/formulae.zip
|
||||||
|
cp path/math.txt train_data/LaTeXOCR
|
||||||
|
# Convert the original .txt file to a .pkl file to group images of different scales
|
||||||
|
# Training set conversion
|
||||||
|
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/train --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
|
||||||
|
# Validation set conversion
|
||||||
|
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/val --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
|
||||||
|
# Test set conversion
|
||||||
|
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/test --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
<a name="3-2"></a>

### 3.2 Training
|
||||||
|
|
||||||
|
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
|
||||||
|
|
||||||
|
```
|
||||||
|
#Single GPU training (Default training method)
|
||||||
|
python3 tools/train.py -c configs/rec/rec_latex_ocr.yml
|
||||||
|
|
||||||
|
#Multi GPU training, specify the gpu number through the --gpus parameter
|
||||||
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_latex_ocr.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="3-3"></a>

### 3.3 Evaluation
|
||||||
|
|
||||||
|
```
|
||||||
|
# GPU evaluation
|
||||||
|
# Validation set evaluation
|
||||||
|
python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True
|
||||||
|
# Test set evaluation
|
||||||
|
python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True Eval.dataset.data=./train_data/LaTeXOCR/latexocr_test.pkl
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="3-4"></a>

### 3.4 Prediction
|
||||||
|
|
||||||
|
```
|
||||||
|
# The configuration file used for prediction must match the training
|
||||||
|
python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True Global.infer_img='./doc/datasets/pme_demo/0000013.png' Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="4"></a>
|
||||||
|
## 4. Inference and Deployment
|
||||||
|
|
||||||
|
<a name="4-1"></a>
|
||||||
|
### 4.1 Python Inference
|
||||||
|
First, convert the model saved during LaTeX-OCR printed mathematical expression recognition training into an inference model. You can use the following command for the conversion:
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Global.save_inference_dir=./inference/rec_latex_ocr_infer/ Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True
|
||||||
|
|
||||||
|
# The default output max length of the model is 512.
|
||||||
|
```
|
||||||
|
|
||||||
|
For LaTeX-OCR printed mathematical expression recognition model inference, the following commands can be executed:
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 tools/infer/predict_rec.py --image_dir='./doc/datasets/pme_demo/0000295.png' --rec_algorithm="LaTeXOCR" --rec_batch_num=1 --rec_model_dir="./inference/rec_latex_ocr_infer/" --rec_char_dict_path="./ppocr/utils/dict/latex_ocr_tokenizer.json"
|
||||||
|
```
|
||||||
|
|
||||||
|
<a name="4-2"></a>
|
||||||
|
### 4.2 C++ Inference
|
||||||
|
|
||||||
|
Not supported
|
||||||
|
|
||||||
|
<a name="4-3"></a>
|
||||||
|
### 4.3 Serving
|
||||||
|
|
||||||
|
Not supported
|
||||||
|
|
||||||
|
<a name="4-4"></a>
|
||||||
|
### 4.4 More
|
||||||
|
|
||||||
|
Not supported
|
||||||
|
|
||||||
|
<a name="5"></a>
|
||||||
|
## 5. FAQ
|
||||||
|
|
||||||
|
|
||||||
|
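1. The LaTeX-OCR dataset comes from the [LaTeX-OCR source repo](https://github.com/lukas-blecher/LaTeX-OCR).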
|
||||||
@ -38,6 +38,7 @@ from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR, LMDBDataSetTable
|
|||||||
from ppocr.data.pgnet_dataset import PGDataSet
|
from ppocr.data.pgnet_dataset import PGDataSet
|
||||||
from ppocr.data.pubtab_dataset import PubTabDataSet
|
from ppocr.data.pubtab_dataset import PubTabDataSet
|
||||||
from ppocr.data.multi_scale_sampler import MultiScaleSampler
|
from ppocr.data.multi_scale_sampler import MultiScaleSampler
|
||||||
|
from ppocr.data.latexocr_dataset import LaTeXOCRDataSet
|
||||||
|
|
||||||
# for PaddleX dataset_type
|
# for PaddleX dataset_type
|
||||||
TextDetDataset = SimpleDataSet
|
TextDetDataset = SimpleDataSet
|
||||||
@ -45,6 +46,7 @@ TextRecDataset = SimpleDataSet
|
|||||||
MSTextRecDataset = MultiScaleDataSet
|
MSTextRecDataset = MultiScaleDataSet
|
||||||
PubTabTableRecDataset = PubTabDataSet
|
PubTabTableRecDataset = PubTabDataSet
|
||||||
KieDataset = SimpleDataSet
|
KieDataset = SimpleDataSet
|
||||||
|
LaTeXOCRDataSet = LaTeXOCRDataSet
|
||||||
|
|
||||||
__all__ = ["build_dataloader", "transform", "create_operators", "set_signal_handlers"]
|
__all__ = ["build_dataloader", "transform", "create_operators", "set_signal_handlers"]
|
||||||
|
|
||||||
@ -94,6 +96,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
|
|||||||
"MSTextRecDataset",
|
"MSTextRecDataset",
|
||||||
"PubTabTableRecDataset",
|
"PubTabTableRecDataset",
|
||||||
"KieDataset",
|
"KieDataset",
|
||||||
|
"LaTeXOCRDataSet",
|
||||||
]
|
]
|
||||||
module_name = config[mode]["dataset"]["name"]
|
module_name = config[mode]["dataset"]["name"]
|
||||||
assert module_name in support_dict, Exception(
|
assert module_name in support_dict, Exception(
|
||||||
|
|||||||
@ -116,3 +116,18 @@ class DyMaskCollator(object):
|
|||||||
label_masks[i][:l] = 1
|
label_masks[i][:l] = 1
|
||||||
|
|
||||||
return images, image_masks, labels, label_masks
|
return images, image_masks, labels, label_masks
|
||||||
|
|
||||||
|
|
||||||
|
class LaTeXOCRCollator(object):
    """Collate callable for the LaTeX-OCR data loaders.

    Only the first entry of the incoming ``batch`` is consulted. It must
    unpack into exactly three items, which are returned unchanged as a
    tuple ``(images, labels, attention_mask)``.

    Shapes as stated by the original author (not verifiable from this
    class alone -- confirm against the dataset implementation):
        images          [batch_size, channel, maxHinbatch, maxWinbatch]
        labels          [batch_size, maxLabelLen]
        attention_mask  [batch_size, maxLabelLen]
    """

    def __call__(self, batch):
        # Any entries after the first one are ignored.
        first_sample = batch[0]
        images, labels, attention_mask = first_sample
        return images, labels, attention_mask
|
||||||
|
|||||||
@ -61,6 +61,7 @@ from .fce_aug import *
|
|||||||
from .fce_targets import FCENetTargets
|
from .fce_targets import FCENetTargets
|
||||||
from .ct_process import *
|
from .ct_process import *
|
||||||
from .drrg_targets import DRRGTargets
|
from .drrg_targets import DRRGTargets
|
||||||
|
from .latex_ocr_aug import *
|
||||||
|
|
||||||
|
|
||||||
def transform(data, ops=None):
|
def transform(data, ops=None):
|
||||||
|
|||||||
@ -25,6 +25,8 @@ import json
|
|||||||
import copy
|
import copy
|
||||||
import random
|
import random
|
||||||
from random import sample
|
from random import sample
|
||||||
|
from collections import defaultdict
|
||||||
|
from tokenizers import Tokenizer as TokenizerFast
|
||||||
|
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
from ppocr.data.imaug.vqa.augment import order_by_tbyx
|
from ppocr.data.imaug.vqa.augment import order_by_tbyx
|
||||||
@ -1770,3 +1772,106 @@ class CPPDLabelEncode(BaseRecLabelEncode):
|
|||||||
if len(text_list) == 0:
|
if len(text_list) == 0:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
return text_list, text_node_index, text_node_num
|
return text_list, text_node_index, text_node_num
|
||||||
|
|
||||||
|
|
||||||
|
class LatexOCRLabelEncode(object):
    """Encode formula strings into padded token-id / attention-mask arrays.

    Args:
        rec_char_dict_path: path handed to ``TokenizerFast.from_file`` to
            load the tokenizer.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        rec_char_dict_path,
        **kwargs,
    ):
        self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
        self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
        # Special token ids; bos/eos are wrapped around each sequence in
        # __call__, and 0 is also the fill value of the padded arrays.
        self.pad_token_id, self.bos_token_id, self.eos_token_id = 0, 1, 2

    def _convert_encoding(
        self,
        encoding,
        return_token_type_ids=None,
        return_attention_mask=None,
        return_overflowing_tokens=False,
        return_special_tokens_mask=False,
        return_offsets_mapping=False,
        return_length=False,
        verbose=True,
    ):
        """Flatten one tokenizer encoding into a dict of per-field lists.

        ``return_token_type_ids`` / ``return_attention_mask`` left as
        ``None`` fall back to whether the field is listed in
        ``self.model_input_names``. ``verbose`` is not used by the body.

        Returns:
            ``(fields, pieces)``: ``fields`` is a ``defaultdict(list)`` with
            one entry per encoding under each requested key, ``pieces`` is
            the list of encodings that were processed (the encoding itself
            plus its overflow when that was requested and present).
        """
        if return_token_type_ids is None:
            return_token_type_ids = "token_type_ids" in self.model_input_names
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        pieces = [encoding]
        if return_overflowing_tokens and encoding.overflowing is not None:
            pieces = pieces + encoding.overflowing

        # (enabled, output key, extractor). The order fixes the key insertion
        # order of the result, which __call__ relies on.
        optional_fields = (
            (return_token_type_ids, "token_type_ids", lambda enc: enc.type_ids),
            (return_attention_mask, "attention_mask", lambda enc: enc.attention_mask),
            (
                return_special_tokens_mask,
                "special_tokens_mask",
                lambda enc: enc.special_tokens_mask,
            ),
            (return_offsets_mapping, "offset_mapping", lambda enc: enc.offsets),
            (return_length, "length", lambda enc: len(enc.ids)),
        )

        fields = defaultdict(list)
        for piece in pieces:
            fields["input_ids"].append(piece.ids)
            for enabled, key, extract in optional_fields:
                if enabled:
                    fields[key].append(extract(piece))

        return fields, pieces

    def encode(
        self,
        text,
        text_pair=None,
        return_token_type_ids=False,
        add_special_tokens=True,
        is_split_into_words=False,
    ):
        """Tokenize a batch of strings and merge the per-sample fields.

        ``text`` is passed straight to ``self.tokenizer.encode_batch``.
        ``text_pair`` and ``return_token_type_ids`` are accepted but not
        used by the body.

        Returns:
            A plain dict mapping each field produced for the first sample
            to the concatenation of that field over all samples.
        """
        raw_encodings = self.tokenizer.encode_batch(
            text,
            add_special_tokens=add_special_tokens,
            is_pretokenized=is_split_into_words,
        )
        per_sample = []
        for raw in raw_encodings:
            fields, _ = self._convert_encoding(
                encoding=raw,
                return_token_type_ids=False,
                return_attention_mask=None,
                return_overflowing_tokens=False,
                return_special_tokens_mask=False,
                return_offsets_mapping=False,
                return_length=False,
                verbose=True,
            )
            per_sample.append(fields)

        merged = {}
        for key in per_sample[0].keys():
            merged[key] = [entry for fields in per_sample for entry in fields[key]]
        return merged

    def __call__(self, eqs):
        """Encode ``eqs`` and right-pad every field to a rectangular array.

        Returns:
            ``(input_ids, attention_mask, max_length)``: two ``int64``
            arrays plus the padded width computed for the last processed
            field.
        """
        encoded = self.encode(eqs)
        # Paired positionally with the dict's keys: the first field gets
        # bos/eos around each sequence, the second gets 1/1 so the mask also
        # covers those two added positions.
        wrappers = ([self.bos_token_id, self.eos_token_id], [1, 1])
        for field, (head, tail) in zip(encoded, wrappers):
            rows = [[head] + ids + [tail] for ids in encoded[field]]
            max_length = max((len(row) for row in rows), default=0)
            padded = np.zeros((len(rows), max_length), dtype="int64")
            for row_idx, row in enumerate(rows):
                padded[row_idx, : len(row)] = row
            encoded[field] = padded

        input_ids = np.array(encoded["input_ids"]).astype(np.int64)
        attention_mask = np.array(encoded["attention_mask"]).astype(np.int64)
        return (input_ids, attention_mask, max_length)
|
||||||
|
|||||||
179
ppocr/data/imaug/latex_ocr_aug.py
Normal file
179
ppocr/data/imaug/latex_ocr_aug.py
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This code is adapted from:
|
||||||
|
https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/dataset/transforms.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import math
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import albumentations as A
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class LatexTrainTransform:
    """Training-time image augmentation for LaTeX-OCR samples.

    Args:
        bitmap_prob: probability with which, before the albumentations
            pipeline runs, every pixel that is not exactly 255 is set to 0.
        **kwargs: accepted and ignored.
    """

    def __init__(self, bitmap_prob=0.04, **kwargs):
        self.bitmap_prob = bitmap_prob

        # Geometric jitter; the pair below is applied together with p=0.15.
        shift_scale_rotate = A.ShiftScaleRotate(
            shift_limit=0,
            scale_limit=(-0.15, 0),
            rotate_limit=1,
            border_mode=0,
            interpolation=3,
            value=[255, 255, 255],
            p=1,
        )
        grid_distortion = A.GridDistortion(
            distort_limit=0.1,
            border_mode=0,
            interpolation=3,
            value=[255, 255, 255],
            p=0.5,
        )
        geometric = A.Compose([shift_scale_rotate, grid_distortion], p=0.15)

        # Colour / noise / compression perturbations, then grayscale.
        photometric = [
            A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.3),
            A.GaussNoise(10, p=0.2),
            A.RandomBrightnessContrast(0.05, (-0.2, 0), True, p=0.2),
            A.ImageCompression(95, p=0.3),
            A.ToGray(always_apply=True),
        ]

        self.train_transform = A.Compose([geometric] + photometric)

    def __call__(self, data):
        """Augment ``data["image"]`` and return the same ``data`` dict."""
        image = data["image"]
        binarise = np.random.random() < self.bitmap_prob
        if binarise:
            # Modifies the incoming array in place.
            image[image != 255] = 0
        data["image"] = self.train_transform(image=image)["image"]
        return data
|
||||||
|
|
||||||
|
|
||||||
|
class LatexTestTransform:
    """Evaluation-time transform: grayscale conversion only."""

    def __init__(self, **kwargs):
        """Build the single-step pipeline; ``kwargs`` are accepted and ignored."""
        self.test_transform = A.Compose([A.ToGray(always_apply=True)])

    def __call__(self, data):
        """Replace ``data["image"]`` with its transformed version and return ``data``."""
        transformed = self.test_transform(image=data["image"])
        data["image"] = transformed["image"]
        return data
|
||||||
|
|
||||||
|
|
||||||
|
class MinMaxResize:
    """Crop a formula image to its content and fit it inside size bounds.

    Images whose width and height already lie within
    ``[min_dimensions, max_dimensions]`` are passed through untouched. Any
    other image is cropped to the bounding box of its text, padded with white
    to a multiple of 32, scaled down if it exceeds the maximum, and padded
    with white if it is below the minimum.

    Both bounds are ``[width, height]`` pairs (``__call__`` compares index 0
    with the image width and index 1 with its height).
    """

    def __init__(self, min_dimensions=(32, 32), max_dimensions=(672, 192), **kwargs):
        """Store the size bounds.

        Args:
            min_dimensions: minimum ``[width, height]``.
            max_dimensions: maximum ``[width, height]``.
            **kwargs: accepted for config compatibility and ignored.
        """
        # Store private copies so instances never share (or mutate) the
        # caller's sequence or the function defaults.
        self.min_dimensions = (
            None if min_dimensions is None else list(min_dimensions)
        )
        self.max_dimensions = (
            None if max_dimensions is None else list(max_dimensions)
        )

    def pad_(self, img, divable=32):
        """Crop ``img`` to its text and pad it to a multiple of ``divable``.

        Args:
            img: PIL image.
            divable: both output sides are rounded up to a multiple of this.

        Returns:
            A mode ``"L"`` PIL image with the cropped content in the top-left
            corner on a white (255) canvas.
        """
        threshold = 128
        data = np.array(img.convert("LA"))
        if data[..., -1].var() == 0:
            # Constant last channel (alpha in "LA" mode): use the gray channel.
            data = (data[..., 0]).astype(np.uint8)
        else:
            # Otherwise the inverted last channel is used as the image.
            data = (255 - data[..., -1]).astype(np.uint8)
        # Stretch to the full 0..255 range.
        # NOTE(review): a perfectly uniform image makes the denominator zero
        # and yields NaNs here, exactly as in the original code.
        data = (data - data.min()) / (data.max() - data.min()) * 255
        if data.mean() > threshold:
            # To invert the text to white
            gray = 255 * (data < threshold).astype(np.uint8)
        else:
            gray = 255 * (data > threshold).astype(np.uint8)
            data = 255 - data

        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
        rect = data[b : b + h, a : a + w]
        im = Image.fromarray(rect).convert("L")
        dims = []
        for x in [w, h]:
            div, mod = divmod(x, divable)
            dims.append(divable * (div + (1 if mod > 0 else 0)))
        padded = Image.new("L", dims, 255)
        padded.paste(im, (0, 0, im.size[0], im.size[1]))
        return padded

    def minmax_size_(self, img, max_dimensions, min_dimensions):
        """Scale ``img`` down to ``max_dimensions`` and pad it up to ``min_dimensions``.

        Args:
            img: PIL image.
            max_dimensions: ``[width, height]`` upper bound, or ``None``.
            min_dimensions: ``[width, height]`` lower bound, or ``None``.

        Returns:
            The resized and/or padded PIL image, or ``img`` itself when it
            already fits.
        """
        if max_dimensions is not None:
            ratios = [a / b for a, b in zip(img.size, max_dimensions)]
            if any(r > 1 for r in ratios):
                size = np.array(img.size) // max(ratios)
                # An extreme aspect ratio can floor the shorter side to 0,
                # which cannot be resized to; keep at least one pixel.
                size = np.maximum(size, 1)
                img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
        if min_dimensions is not None:
            # If any side is smaller than the minimum, grow the canvas so that
            # every side is >= the minimum.
            padded_size = [
                max(img_dim, min_dim)
                for img_dim, min_dim in zip(img.size, min_dimensions)
            ]
            if padded_size != list(img.size):
                padded_im = Image.new("L", padded_size, 255)
                # Paste at the top-left corner. Using img.getbbox() as the box
                # fails when the border pixels are zero, because the bounding
                # box of non-zero pixels is then smaller than the image.
                padded_im.paste(img, (0, 0))
                img = padded_im
        return img

    def __call__(self, data):
        """Resize ``data["image"]`` if it falls outside the configured bounds.

        Returns ``data`` unchanged when the image already fits. Otherwise
        ``data["image"]`` is replaced by the processed gray image stacked
        three times along a new last axis.
        """
        img = data["image"]
        h, w = img.shape[:2]
        fits = (
            self.min_dimensions[0] <= w <= self.max_dimensions[0]
            and self.min_dimensions[1] <= h <= self.max_dimensions[1]
        )
        if fits:
            return data

        im = Image.fromarray(np.uint8(img))
        im = self.minmax_size_(self.pad_(im), self.max_dimensions, self.min_dimensions)
        im = np.array(im)
        data["image"] = np.dstack((im, im, im))
        return data
|
||||||
|
|
||||||
|
|
||||||
|
class LatexImageFormat:
    """Convert an H x W x C image into a 1 x H' x W' array.

    Only channel 0 is kept. H' and W' are H and W rounded up to the next
    multiple of 16, and the added bottom/right cells are filled with 1.
    """

    def __init__(self, **kwargs):
        """No configuration is needed; ``kwargs`` are accepted and ignored."""

    def __call__(self, data):
        """Format ``data["image"]`` in place of the original and return ``data``."""
        image = data["image"]
        height, width = image.shape[:2]
        target_h = 16 * math.ceil(height / 16)
        target_w = 16 * math.ceil(width / 16)

        first_channel = image[:, :, 0]
        pad_widths = ((0, target_h - height), (0, target_w - width))
        padded = np.pad(first_channel, pad_widths, constant_values=(1, 1))

        # Add a leading channel axis: (H', W') -> (1, H', W').
        data["image"] = np.expand_dims(padded, axis=0)
        return data
|
||||||
172
ppocr/data/latexocr_dataset.py
Normal file
172
ppocr/data/latexocr_dataset.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This code is adapted from:
|
||||||
|
https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/dataset/dataset.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
import traceback
|
||||||
|
import paddle
|
||||||
|
from paddle.io import Dataset
|
||||||
|
from .imaug.label_ops import LatexOCRLabelEncode
|
||||||
|
from .imaug import transform, create_operators
|
||||||
|
|
||||||
|
|
||||||
|
class LaTeXOCRDataSet(Dataset):
    """Dataset whose items are pre-built batches of formula images.

    The index is a pickled mapping. Its keys are indexable and ``key[0]`` /
    ``key[1]`` are checked against ``min_dimensions`` / ``max_dimensions``
    (presumably image width and height -- confirm against the index builder).
    Each value is a sequence of ``(equation, image_path)`` rows. Rows that
    share a key are grouped into chunks of ``batch_size_per_pair`` and every
    chunk becomes one dataset item.
    """

    def __init__(self, config, mode, logger, seed=None):
        """Load the pickled index and pre-compute the batches.

        Args:
            config: full config with ``Global`` and ``config[mode]`` holding
                ``dataset`` and ``loader`` sections. Several keys are popped
                from ``dataset`` and ``Global`` (the config is mutated).
            mode: config section name; its lower-cased form is stored in
                ``self.mode``.
            logger: logger used to report per-item failures.
            seed: value passed to ``random.seed`` and forwarded to
                ``set_epoch_as_seed``.
        """
        super(LaTeXOCRDataSet, self).__init__()
        self.logger = logger
        self.mode = mode.lower()

        global_config = config["Global"]
        dataset_config = config[mode]["dataset"]
        loader_config = config[mode]["loader"]

        pkl_path = dataset_config.pop("data")
        self.min_dimensions = dataset_config.pop("min_dimensions")
        self.max_dimensions = dataset_config.pop("max_dimensions")
        self.batchsize = dataset_config.pop("batch_size_per_pair")
        self.keep_smaller_batches = dataset_config.pop("keep_smaller_batches")
        self.max_seq_len = global_config.pop("max_seq_len")
        self.rec_char_dict_path = global_config.pop("rec_char_dict_path")
        self.tokenizer = LatexOCRLabelEncode(self.rec_char_dict_path)

        # Close the index file deterministically instead of leaking the handle.
        # NOTE: pickle.load can execute arbitrary code; only load index files
        # from a trusted source.
        with open(pkl_path, "rb") as file:
            data = pickle.load(file)

        # Keep only the groups whose key lies within the configured bounds.
        self.data = {
            k: data[k]
            for k in data
            if (
                self.min_dimensions[0] <= k[0] <= self.max_dimensions[0]
                and self.min_dimensions[1] <= k[1] <= self.max_dimensions[1]
            )
        }
        self.do_shuffle = loader_config["shuffle"]
        self.seed = seed

        shuffle_rows = self.mode == "train" and self.do_shuffle
        if shuffle_rows:
            random.seed(self.seed)

        self.pairs = []
        for k in self.data:
            info = np.array(self.data[k], dtype=object)
            p = paddle.randperm(len(info)) if shuffle_rows else paddle.arange(len(info))
            for i in range(0, len(info), self.batchsize):
                batch = info[p[i : i + self.batchsize]]
                if len(batch.shape) == 1:
                    # A single selected row comes back 1-D; restore the batch axis.
                    batch = batch[None, :]
                if len(batch) < self.batchsize and not self.keep_smaller_batches:
                    continue
                self.pairs.append(batch)

        self.pairs = np.array(self.pairs, dtype=object)
        if self.do_shuffle:
            self.pairs = np.random.permutation(self.pairs)

        self.size = len(self.pairs)
        self.set_epoch_as_seed(self.seed, dataset_config)

        self.ops = create_operators(dataset_config["transforms"], global_config)
        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2)
        self.need_reset = True

    def set_epoch_as_seed(self, seed, dataset_config):
        """Write ``seed`` into the MakeBorderMap/MakeShrinkMap transform configs.

        Only acts in train mode. If either transform is missing from
        ``dataset_config["transforms"]`` the resulting exception is printed
        and nothing is changed.
        """
        if self.mode == "train":
            try:
                border_map_id = [
                    index
                    for index, dictionary in enumerate(dataset_config["transforms"])
                    if "MakeBorderMap" in dictionary
                ][0]
                shrink_map_id = [
                    index
                    for index, dictionary in enumerate(dataset_config["transforms"])
                    if "MakeShrinkMap" in dictionary
                ][0]
                epoch = seed if seed is not None else 0
                dataset_config["transforms"][border_map_id]["MakeBorderMap"][
                    "epoch"
                ] = epoch
                dataset_config["transforms"][shrink_map_id]["MakeShrinkMap"][
                    "epoch"
                ] = epoch
            except Exception as E:
                print(E)
                return

    def shuffle_data_random(self):
        """Shuffle ``self.data_lines`` with ``self.seed``.

        NOTE(review): this class never assigns ``self.data_lines``, so calling
        this raises AttributeError unless a caller sets it -- confirm whether
        the method is still needed.
        """
        random.seed(self.seed)
        random.shuffle(self.data_lines)
        return

    def _fallback_index(self, idx):
        """Pick the replacement item used when item ``idx`` cannot be served."""
        # During evaluation the replacement must be deterministic so repeated
        # evaluations give the same result; training picks a random item.
        if self.mode == "train":
            return np.random.randint(self.__len__())
        return (idx + 1) % self.__len__()

    def __getitem__(self, idx):
        """Return ``(images, labels, attention_mask)`` for batch ``idx``.

        ``images`` is a float32 array built by concatenating the first output
        of the transform pipeline for every image in the batch and inserting a
        channel axis at position 1. If loading fails, or the tokenized length
        exceeds ``max_seq_len``, another batch is returned instead (see
        ``_fallback_index``).
        """
        batch = self.pairs[idx]
        eqs, ims = batch.T
        # Tracked separately so the error message is valid even when the
        # failure happens before or after the per-image loop.
        current_path = None
        try:
            images_transform = []
            for img_path in ims:
                current_path = img_path
                data = {"img_path": img_path}
                with open(img_path, "rb") as f:
                    data["image"] = f.read()
                item = transform(data, self.ops)
                images_transform.append(np.array(item[0]))
            image_concat = np.concatenate(images_transform, axis=0)[:, np.newaxis, :, :]
            images_transform = image_concat.astype(np.float32)
            labels, attention_mask, max_length = self.tokenizer(list(eqs))
            if self.max_seq_len >= max_length:
                return (images_transform, labels, attention_mask)
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt/SystemExit
            # are not swallowed by the retry logic.
            self.logger.error(
                "When parsing line {}, error happened with msg: {}".format(
                    current_path, traceback.format_exc()
                )
            )

        return self.__getitem__(self._fallback_index(idx))

    def __len__(self):
        """Number of pre-built batches."""
        return self.size
|
||||||
@ -45,6 +45,7 @@ from .rec_satrn_loss import SATRNLoss
|
|||||||
from .rec_nrtr_loss import NRTRLoss
|
from .rec_nrtr_loss import NRTRLoss
|
||||||
from .rec_parseq_loss import ParseQLoss
|
from .rec_parseq_loss import ParseQLoss
|
||||||
from .rec_cppd_loss import CPPDLoss
|
from .rec_cppd_loss import CPPDLoss
|
||||||
|
from .rec_latexocr_loss import LaTeXOCRLoss
|
||||||
|
|
||||||
# cls loss
|
# cls loss
|
||||||
from .cls_loss import ClsLoss
|
from .cls_loss import ClsLoss
|
||||||
@ -107,6 +108,7 @@ def build_loss(config):
|
|||||||
"NRTRLoss",
|
"NRTRLoss",
|
||||||
"ParseQLoss",
|
"ParseQLoss",
|
||||||
"CPPDLoss",
|
"CPPDLoss",
|
||||||
|
"LaTeXOCRLoss",
|
||||||
]
|
]
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_name = config.pop("name")
|
module_name = config.pop("name")
|
||||||
|
|||||||
47
ppocr/losses/rec_latexocr_loss.py
Normal file
47
ppocr/losses/rec_latexocr_loss.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This code is adapted from:
|
||||||
|
https://github.com/lucidrains/x-transformers/blob/main/x_transformers/autoregressive_wrapper.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class LaTeXOCRLoss(nn.Layer):
    """Mean cross-entropy loss used to train the LaTeX-OCR model.

    Targets equal to ``ignore_index`` (-100) are passed to
    ``nn.CrossEntropyLoss`` as the index to ignore.
    """

    def __init__(self):
        super(LaTeXOCRLoss, self).__init__()
        self.ignore_index = -100
        self.cross = nn.CrossEntropyLoss(
            reduction="mean", ignore_index=self.ignore_index
        )

    def forward(self, preds, batch):
        """Compute the loss.

        Args:
            preds: model output; its last axis is used as the class axis and
                all leading axes are flattened.
            batch: sequence whose element 1 holds the label ids. The first
                column is dropped before flattening (presumably the start
                token -- confirm against the label encoder).

        Returns:
            ``{"loss": <scalar loss>}``.
        """
        num_classes = preds.shape[-1]
        flat_preds = paddle.reshape(preds, [-1, num_classes])
        flat_labels = paddle.reshape(batch[1][:, 1:], [-1])
        return {"loss": self.cross(flat_preds, flat_labels)}
|
||||||
@ -22,7 +22,7 @@ import copy
|
|||||||
__all__ = ["build_metric"]
|
__all__ = ["build_metric"]
|
||||||
|
|
||||||
from .det_metric import DetMetric, DetFCEMetric
|
from .det_metric import DetMetric, DetFCEMetric
|
||||||
from .rec_metric import RecMetric, CNTMetric, CANMetric
|
from .rec_metric import RecMetric, CNTMetric, CANMetric, LaTeXOCRMetric
|
||||||
from .cls_metric import ClsMetric
|
from .cls_metric import ClsMetric
|
||||||
from .e2e_metric import E2EMetric
|
from .e2e_metric import E2EMetric
|
||||||
from .distillation_metric import DistillationMetric
|
from .distillation_metric import DistillationMetric
|
||||||
@ -50,6 +50,7 @@ def build_metric(config):
|
|||||||
"CTMetric",
|
"CTMetric",
|
||||||
"CNTMetric",
|
"CNTMetric",
|
||||||
"CANMetric",
|
"CANMetric",
|
||||||
|
"LaTeXOCRMetric",
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|||||||
240
ppocr/metrics/bleu.py
Normal file
240
ppocr/metrics/bleu.py
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This code is adapted from:
|
||||||
|
https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import math
|
||||||
|
import collections
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
|
||||||
|
def _get_ngrams(segment, max_order):
    """Extract all n-grams up to a given maximum order from an input segment.

    Args:
        segment: sequence of tokens from which n-grams are extracted.
        max_order: maximum length, in tokens, of the n-grams returned.

    Returns:
        A Counter mapping each n-gram (as a tuple) of order 1..max_order to
        the number of times it occurs in ``segment``.
    """
    ngram_counts = collections.Counter()
    for order in range(1, max_order + 1):
        for i in range(0, len(segment) - order + 1):
            ngram_counts[tuple(segment[i : i + order])] += 1
    return ngram_counts


def compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=False):
    """Compute the BLEU score of translated segments against references.

    Args:
        reference_corpus: list of lists of references for each translation.
            Each reference should be tokenized into a list of tokens.
        translation_corpus: list of translations to score. Each translation
            should be tokenized into a list of tokens.
        max_order: maximum n-gram order to use when computing the score.
        smooth: whether to apply Lin et al. 2004 (add-one) smoothing.

    Returns:
        A 6-tuple ``(bleu, precisions, bp, ratio, translation_length,
        reference_length)`` where ``precisions`` is the list of n-gram
        precisions, ``bp`` the brevity penalty and ``ratio`` the
        translation/reference length ratio. When the translations are empty
        the brevity penalty (and therefore the score) is 0 instead of raising
        ZeroDivisionError.
    """
    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order
    reference_length = 0
    translation_length = 0
    for references, translation in zip(reference_corpus, translation_corpus):
        # default=0 keeps a segment without any reference from raising.
        reference_length += min((len(r) for r in references), default=0)
        translation_length += len(translation)

        # Clip each n-gram count by its maximum count in any single reference.
        merged_ref_ngram_counts = collections.Counter()
        for reference in references:
            merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
        translation_ngram_counts = _get_ngrams(translation, max_order)
        overlap = translation_ngram_counts & merged_ref_ngram_counts
        for ngram, count in overlap.items():
            matches_by_order[len(ngram) - 1] += count
        for order in range(1, max_order + 1):
            possible_matches = len(translation) - order + 1
            if possible_matches > 0:
                possible_matches_by_order[order - 1] += possible_matches

    precisions = [0] * max_order
    for i in range(max_order):
        if smooth:
            precisions[i] = (matches_by_order[i] + 1.0) / (
                possible_matches_by_order[i] + 1.0
            )
        elif possible_matches_by_order[i] > 0:
            precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
        else:
            precisions[i] = 0.0

    if min(precisions) > 0:
        p_log_sum = sum((1.0 / max_order) * math.log(p) for p in precisions)
        geo_mean = math.exp(p_log_sum)
    else:
        geo_mean = 0

    if reference_length > 0:
        ratio = float(translation_length) / reference_length
    else:
        # No reference tokens: any output is "longer", no output is ratio 0.
        ratio = float("inf") if translation_length > 0 else 0.0

    if ratio > 1.0:
        bp = 1.0
    elif ratio > 0:
        bp = math.exp(1 - 1.0 / ratio)
    else:
        # Limit of exp(1 - 1/ratio) as ratio -> 0+, avoiding division by zero.
        bp = 0.0

    bleu = geo_mean * bp

    return (bleu, precisions, bp, ratio, translation_length, reference_length)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTokenizer:
    """A base dummy tokenizer to derive from."""

    def signature(self):
        """
        Returns a signature for the tokenizer.
        :return: signature string
        """
        return "none"

    def __call__(self, line):
        """
        Tokenizes an input line with the tokenizer.
        :param line: a segment to tokenize
        :return: the tokenized line
        """
        return line


class TokenizerRegexp(BaseTokenizer):
    """Regex-based tokenizer that splits punctuation into separate tokens."""

    # Compiled once for the whole class; applied in order to every line.
    _RULES = (
        # language-dependent part (assuming Western languages)
        (re.compile(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])"), r" \1 "),
        # tokenize period and comma unless preceded by a digit
        (re.compile(r"([^0-9])([\.,])"), r"\1 \2 "),
        # tokenize period and comma unless followed by a digit
        (re.compile(r"([\.,])([^0-9])"), r" \1 \2"),
        # tokenize dash when preceded by a digit
        (re.compile(r"([0-9])(-)"), r"\1 \2 "),
    )

    def signature(self):
        return "re"

    def __init__(self):
        # Kept as an instance attribute for backward compatibility; the
        # cached tokenization below uses the class-level rules.
        self._re = list(self._RULES)

    @staticmethod
    @lru_cache(maxsize=2**16)
    def _tokenize_cached(line):
        """Apply every rule to ``line`` and return the tokens as a tuple.

        The cache is keyed on the line only (not on an instance), and stores
        immutable tuples so cached results cannot be modified by callers.
        """
        for pattern, repl in TokenizerRegexp._RULES:
            line = pattern.sub(repl, line)
        # Splitting in Python is faster than collapsing whitespace by regex.
        return tuple(line.split())

    def __call__(self, line):
        """Common post-processing tokenizer for `13a` and `zh` tokenizers.

        :param line: a segment to tokenize
        :return: the tokenized line as a new list of individual words
        """
        return list(self._tokenize_cached(line))


class Tokenizer13a(BaseTokenizer):
    """Minimal tokenizer equivalent to mteval-v13a, as used by WMT."""

    def signature(self):
        return "13a"

    def __init__(self):
        self._post_tokenizer = TokenizerRegexp()

    def __call__(self, line):
        """Tokenizes an input line using a relatively minimal tokenization
        that is however equivalent to mteval-v13a, used by WMT.

        :param line: a segment to tokenize
        :return: the tokenized line as a list of words
        """
        # language-independent part:
        line = line.replace("<skipped>", "")
        line = line.replace("-\n", "")
        line = line.replace("\n", " ")

        if "&" in line:
            # Decode the four HTML entities handled by mteval-v13a. "&amp;" is
            # decoded before "&lt;"/"&gt;", matching the original order.
            line = line.replace("&quot;", '"')
            line = line.replace("&amp;", "&")
            line = line.replace("&lt;", "<")
            line = line.replace("&gt;", ">")

        # The regex pass is cached inside TokenizerRegexp, so no second cache
        # is needed here.
        return self._post_tokenizer(f" {line} ")
|
||||||
|
|
||||||
|
|
||||||
|
def compute_blue_score(
    predictions, references, tokenizer=Tokenizer13a(), max_order=4, smooth=False
):
    """Tokenize predictions and references and return their corpus BLEU score.

    Args:
        predictions: list of predicted strings.
        references: list of reference strings, or list of lists of reference
            strings when several references exist per prediction.
        tokenizer: callable turning a string into a list of tokens.
        max_order: maximum n-gram order passed to ``compute_bleu``.
        smooth: smoothing flag passed to ``compute_bleu``.

    Returns:
        The BLEU score (first element of the ``compute_bleu`` result).
    """
    # A flat list of strings means one reference per prediction; wrap each so
    # the corpus is always a list of reference lists.
    if isinstance(references[0], str):
        references = [[single] for single in references]

    tokenized_refs = [[tokenizer(text) for text in group] for group in references]
    tokenized_preds = [tokenizer(text) for text in predictions]
    bleu, _precisions, _bp, _ratio, _trans_len, _ref_len = compute_bleu(
        reference_corpus=tokenized_refs,
        translation_corpus=tokenized_preds,
        max_order=max_order,
        smooth=smooth,
    )
    return bleu
|
||||||
|
|
||||||
|
|
||||||
|
def cal_distance(word1, word2):
    """Return the Levenshtein (edit) distance between two sequences.

    Insertions, deletions and substitutions each cost 1. The inputs may be
    strings or lists of tokens; elements are compared with ``!=``.
    """
    rows, cols = len(word1), len(word2)
    if rows == 0 or cols == 0:
        # Distance to an empty sequence is the length of the other one.
        return rows + cols

    # previous[j] holds the distance between word1[:i-1] and word2[:j].
    previous = list(range(cols + 1))
    for i in range(1, rows + 1):
        current = [i]
        for j in range(1, cols + 1):
            delete_cost = previous[j] + 1
            insert_cost = current[j - 1] + 1
            replace_cost = previous[j - 1]
            if word1[i - 1] != word2[j - 1]:
                replace_cost += 1
            current.append(min(delete_cost, insert_cost, replace_cost))
        previous = current
    return previous[cols]
|
||||||
|
|
||||||
|
|
||||||
|
def compute_edit_distance(prediction, label):
    """Return the token-level edit distance between two strings.

    Both strings are stripped and split on single spaces before the
    distance is computed with ``cal_distance``.
    """
    predicted_tokens = prediction.strip().split(" ")
    label_tokens = label.strip().split(" ")
    return cal_distance(predicted_tokens, label_tokens)
|
||||||
@ -17,6 +17,7 @@ from difflib import SequenceMatcher
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import string
|
import string
|
||||||
|
from .bleu import compute_blue_score, compute_edit_distance
|
||||||
|
|
||||||
|
|
||||||
class RecMetric(object):
|
class RecMetric(object):
|
||||||
@ -177,3 +178,121 @@ class CANMetric(object):
|
|||||||
self.exp_right = []
|
self.exp_right = []
|
||||||
self.word_total_length = 0
|
self.word_total_length = 0
|
||||||
self.exp_total_num = 0
|
self.exp_total_num = 0
|
||||||
|
|
||||||
|
|
||||||
|
class LaTeXOCRMetric(object):
    """Accumulates LaTeX-OCR evaluation statistics over batches.

    Per batch it records the number of exact matches, the summed normalized
    Levenshtein distance, the number of predictions within token edit
    distance 1, 2 and 3 of the label and, optionally, the BLEU score.
    ``get_metric`` turns the accumulated sums into per-sample averages.
    """

    def __init__(self, main_indicator="exp_rate", cal_blue_score=False, **kwargs):
        """
        Args:
            main_indicator: name of the metric treated as the main one.
            cal_blue_score: also compute and report the BLEU score.
            **kwargs: accepted for config compatibility and ignored.
        """
        self.main_indicator = main_indicator
        self.cal_blue_score = cal_blue_score
        # Always present, even when BLEU is disabled.
        self.blue_right = []
        self.reset()
        self.epoch_reset()

    def __call__(self, preds, batch, **kwargs):
        """Accumulate statistics for one batch.

        Args:
            preds: iterable of predicted strings.
            batch: iterable of label strings, aligned with ``preds``.
            **kwargs: if any keyword value is truthy the accumulated epoch
                statistics are cleared first (presumably passed as
                ``epoch_reset=...`` -- confirm against the caller).
        """
        if any(kwargs.values()):
            self.epoch_reset()

        word_pred = preds
        word_label = batch
        line_right, e1, e2, e3 = 0, 0, 0, 0
        lev_dist = []
        for labels, prediction in zip(word_label, word_pred):
            if prediction == labels:
                line_right += 1
            distance = compute_edit_distance(prediction, labels)
            lev_dist.append(Levenshtein.normalized_distance(prediction, labels))
            if distance <= 1:
                e1 += 1
            if distance <= 2:
                e2 += 1
            if distance <= 3:
                e3 += 1

        batch_size = len(lev_dist)

        # Per-batch values (sums / counts, not yet averaged).
        self.edit_dist = sum(lev_dist)
        self.exp_rate = line_right
        if self.cal_blue_score:
            self.blue_score = compute_blue_score(word_pred, word_label)
        self.e1 = e1
        self.e2 = e2
        self.e3 = e3

        exp_length = len(word_label)
        self.edit_right.append(self.edit_dist)
        self.exp_right.append(self.exp_rate)
        if self.cal_blue_score:
            # Weight the batch BLEU by batch size so the average is per sample.
            self.blue_right.append(self.blue_score * batch_size)
        self.e1_right.append(self.e1)
        self.e2_right.append(self.e2)
        self.e3_right.append(self.e3)
        self.editdistance_total_length = self.editdistance_total_length + exp_length
        self.exp_total_num = self.exp_total_num + exp_length

    @staticmethod
    def _ratio(total, count):
        """Return ``total / count``, or 0.0 when nothing was accumulated."""
        return total / count if count else 0.0

    def get_metric(self):
        """Return the per-sample averages accumulated so far.

        The result contains the edit distance, the exact-match rate and the
        rates of predictions within edit distance 1, 2 and 3; the BLEU score
        is included when ``cal_blue_score`` is set. All values are 0.0 when
        no sample has been accumulated yet (instead of raising
        ZeroDivisionError). The per-batch values are reset afterwards; the
        epoch accumulators are kept until ``epoch_reset``.
        """
        total = self.exp_total_num
        cur_edit_distance = self._ratio(sum(self.edit_right), total)
        cur_exp_rate = self._ratio(sum(self.exp_right), total)
        if self.cal_blue_score:
            cur_blue_score = self._ratio(
                sum(self.blue_right), self.editdistance_total_length
            )
        cur_exp_1 = self._ratio(sum(self.e1_right), total)
        cur_exp_2 = self._ratio(sum(self.e2_right), total)
        cur_exp_3 = self._ratio(sum(self.e3_right), total)
        self.reset()
        if self.cal_blue_score:
            # NOTE(review): these keys carry a trailing space, so the default
            # main_indicator "exp_rate" is not a key of this dict -- confirm
            # how callers look up the main indicator when BLEU is enabled.
            return {
                "blue_score ": cur_blue_score,
                "edit distance ": cur_edit_distance,
                "exp_rate ": cur_exp_rate,
                "exp_rate<=1 ": cur_exp_1,
                "exp_rate<=2 ": cur_exp_2,
                "exp_rate<=3 ": cur_exp_3,
            }
        else:
            return {
                "edit distance": cur_edit_distance,
                "exp_rate": cur_exp_rate,
                "exp_rate<=1 ": cur_exp_1,
                "exp_rate<=2 ": cur_exp_2,
                "exp_rate<=3 ": cur_exp_3,
            }

    def reset(self):
        """Clear the values of the most recent batch."""
        self.edit_dist = 0
        self.exp_rate = 0
        if self.cal_blue_score:
            self.blue_score = 0
        self.e1 = 0
        self.e2 = 0
        self.e3 = 0

    def epoch_reset(self):
        """Clear everything accumulated over the current epoch."""
        self.edit_right = []
        self.exp_right = []
        if self.cal_blue_score:
            self.blue_right = []
        self.e1_right = []
        self.e2_right = []
        self.e3_right = []
        self.editdistance_total_length = 0
        self.exp_total_num = 0
|
||||||
|
|||||||
@ -59,6 +59,8 @@ def build_backbone(config, model_type):
|
|||||||
from .rec_vitstr import ViTSTR
|
from .rec_vitstr import ViTSTR
|
||||||
from .rec_resnet_rfl import ResNetRFL
|
from .rec_resnet_rfl import ResNetRFL
|
||||||
from .rec_densenet import DenseNet
|
from .rec_densenet import DenseNet
|
||||||
|
from .rec_resnetv2 import ResNetV2
|
||||||
|
from .rec_hybridvit import HybridTransformer
|
||||||
from .rec_shallow_cnn import ShallowCNN
|
from .rec_shallow_cnn import ShallowCNN
|
||||||
from .rec_lcnetv3 import PPLCNetV3
|
from .rec_lcnetv3 import PPLCNetV3
|
||||||
from .rec_hgnet import PPHGNet_small
|
from .rec_hgnet import PPHGNet_small
|
||||||
@ -89,6 +91,8 @@ def build_backbone(config, model_type):
|
|||||||
"ViT",
|
"ViT",
|
||||||
"RepSVTR",
|
"RepSVTR",
|
||||||
"SVTRv2",
|
"SVTRv2",
|
||||||
|
"ResNetV2",
|
||||||
|
"HybridTransformer",
|
||||||
]
|
]
|
||||||
elif model_type == "e2e":
|
elif model_type == "e2e":
|
||||||
from .e2e_resnet_vd_pg import ResNet
|
from .e2e_resnet_vd_pg import ResNet
|
||||||
|
|||||||
529
ppocr/modeling/backbones/rec_hybridvit.py
Normal file
529
ppocr/modeling/backbones/rec_hybridvit.py
Normal file
@ -0,0 +1,529 @@
|
|||||||
|
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This code is adapted from:
|
||||||
|
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from itertools import repeat
|
||||||
|
import collections
|
||||||
|
import math
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from ppocr.modeling.backbones.rec_resnetv2 import (
|
||||||
|
ResNetV2,
|
||||||
|
StdConv2dSame,
|
||||||
|
DropPath,
|
||||||
|
get_padding,
|
||||||
|
)
|
||||||
|
from paddle.nn.initializer import (
|
||||||
|
TruncatedNormal,
|
||||||
|
Constant,
|
||||||
|
Normal,
|
||||||
|
KaimingUniform,
|
||||||
|
XavierUniform,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shared parameter initializers used by the layers in this module.
# Constant fills.
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
# Random fills.
normal_ = Normal(mean=0.0, std=1e-6)
trunc_normal_ = TruncatedNormal(std=0.02)
xavier_uniform_ = XavierUniform()
# NOTE: despite its name, this is bound to a KaimingUniform initializer.
kaiming_normal_ = KaimingUniform(nonlinearity="relu")
|
||||||
|
|
||||||
|
|
||||||
|
def _ntuple(n):
    """Build a converter that expands a scalar into an ``n``-tuple.

    Args:
        n: number of repetitions used when a scalar is given.

    Returns:
        A function ``parse(x)`` that returns ``x`` unchanged when it is already
        a (non-string) iterable, and ``(x,) * n`` otherwise.
    """
    # Import the submodule explicitly: a bare ``import collections`` does not
    # guarantee that ``collections.abc`` has been loaded.
    from collections.abc import Iterable

    def parse(x):
        # Strings/bytes are iterable but are never meant as per-dimension
        # values, so they are repeated like any other scalar.
        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dAlign(nn.Conv2D):
    """Thin ``nn.Conv2D`` wrapper with a torch-style constructor signature.

    The forward pass is a plain ``F.conv2d`` using the layer's own weight and
    bias; no weight standardization is applied here, and ``eps`` is only
    stored on the instance.

    Args:
        in_channel: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        dilation: convolution dilation.
        groups: number of groups.
        bias: forwarded as ``bias_attr`` (``False`` disables the bias).
        eps: stored as ``self.eps``; not used by ``forward``.
    """

    def __init__(
        self,
        in_channel,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        eps=1e-6,
    ):
        super().__init__(
            in_channel,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias,
            weight_attr=True,
        )
        self.eps = eps

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        x = F.conv2d(
            x,
            self.weight,
            self.bias,
            self._stride,
            self._padding,
            self._dilation,
            self._groups,
        )
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class HybridEmbed(nn.Layer):
    """CNN feature-map embedding.

    Runs a CNN backbone, projects its feature map to ``embed_dim`` with a
    convolution and flattens the spatial grid into a token sequence.

    Args:
        backbone: ``nn.Layer`` producing the feature map; when it returns a
            list/tuple, the last element is used.
        img_size: stored on the instance as ``self.img_size``.
        patch_size: stored on the instance as ``self.patch_size``. The
            projection itself always uses a (1, 1) kernel and stride.
        feature_size: size of the backbone feature grid. Defaults to
            ``(42, 12)`` (presumably a 672x192 input downsampled by 16 —
            confirm against the training config).
        in_chans: accepted for signature compatibility; unused here.
        embed_dim: number of output channels of the projection.
    """

    def __init__(
        self,
        backbone,
        img_size=224,
        patch_size=1,
        feature_size=None,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Layer)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        # The backbone output is assumed to have 1024 channels.
        feature_dim = 1024
        # Honour an explicit feature_size instead of silently discarding it.
        if feature_size is None:
            feature_size = (42, 12)
        else:
            feature_size = to_2tuple(feature_size)
        # The projection always works per feature-map cell.
        patch_size = (1, 1)
        assert (
            feature_size[0] % patch_size[0] == 0
            and feature_size[1] % patch_size[1] == 0
        )

        self.grid_size = (
            feature_size[0] // patch_size[0],
            feature_size[1] // patch_size[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2D(
            feature_dim,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            weight_attr=True,
            bias_attr=True,
        )

    def forward(self, x):
        """Return the projected backbone features as ``[N, tokens, embed_dim]``."""
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        # [N, C, H, W] -> [N, C, H*W] -> [N, H*W, C]
        x = self.proj(x).flatten(2).transpose([0, 2, 1])

        return x
|
||||||
|
|
||||||
|
|
||||||
|
class myLinear(nn.Linear):
    """``nn.Linear`` variant that multiplies by the transposed weight.

    ``forward`` computes ``x @ weight^T (+ bias)``.
    NOTE(review): the parent layer allocates the weight itself, so the
    transposed product presumably only type-checks when
    ``in_channel == out_channels`` (the only use in this file is
    ``myLinear(dim, dim)``) — confirm before using other shapes.

    Args:
        in_channel: input feature size.
        out_channels: output feature size.
        weight_attr: forwarded to ``nn.Linear``.
        bias_attr: forwarded to ``nn.Linear``; ``False`` disables the bias.
    """

    def __init__(self, in_channel, out_channels, weight_attr=True, bias_attr=True):
        super().__init__(
            in_channel, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
        )

    def forward(self, x):
        """Return ``x @ weight^T`` plus the bias when one exists."""
        out = paddle.matmul(x, self.weight, transpose_y=True)
        # The bias is None when the layer was built with bias_attr=False.
        if self.bias is not None:
            out = out + self.bias
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Layer):
    """Multi-head self-attention over a ``[B, N, C]`` token sequence.

    Args:
        dim: channel size ``C`` of the input tokens.
        num_heads: number of attention heads.
        qkv_bias: forwarded as ``bias_attr`` of the fused q/k/v projection.
        attn_drop: dropout rate applied to the attention weights.
        proj_drop: dropout rate applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        # Scores are scaled by 1 / sqrt(head_dim).
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = myLinear(dim, dim, weight_attr=True, bias_attr=True)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended tokens with the same shape as ``x``."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # [B, N, 3*C] -> [3, B, heads, N, head_dim]
        fused = self.qkv(x).reshape([batch, tokens, 3, self.num_heads, head_dim])
        q, k, v = fused.transpose([2, 0, 3, 1, 4]).unbind(0)

        scores = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        weights = self.attn_drop(F.softmax(scores, axis=-1))

        # [B, heads, N, head_dim] -> [B, N, C]
        merged = (weights @ v).transpose([0, 2, 1, 3]).reshape(
            [batch, tokens, channels]
        )
        return self.proj_drop(self.proj(merged))
|
||||||
|
|
||||||
|
|
||||||
|
class Mlp(nn.Layer):
    """Two-layer feed-forward network: linear, activation, dropout, linear, dropout.

    Args:
        in_features: input feature size.
        hidden_features: hidden size; falls back to ``in_features``.
        out_features: output size; falls back to ``in_features``.
        act_layer: activation layer class, instantiated without arguments.
        drop: dropout rate, or a pair of rates (one per dropout layer).
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        if not hidden_features:
            hidden_features = in_features
        if not out_features:
            out_features = in_features
        rates = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(rates[0])
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(rates[1])

    def forward(self, x):
        """Apply the sub-layers in order and return the result."""
        for layer in (self.fc1, self.act, self.drop1, self.fc2, self.drop2):
            x = layer(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Layer):
    """Pre-norm transformer block: attention and MLP, each with a residual add.

    Args:
        dim: token channel size.
        num_heads: number of attention heads.
        mlp_ratio: hidden size of the MLP as a multiple of ``dim``.
        qkv_bias: forwarded to ``Attention``.
        drop: dropout rate for the attention projection and the MLP.
        attn_drop: dropout rate for the attention weights.
        drop_path: stochastic-depth rate; ``0`` uses an identity layer.
        act_layer: activation layer class for the MLP.
        norm_layer: normalization layer class, called as ``norm_layer(dim)``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth is only instantiated when it is actually enabled.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual branches."""
        attended = self.attn(self.norm1(x))
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x))
        x = x + self.drop_path(transformed)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class HybridTransformer(nn.Layer):
    """Hybrid ResNetV2 + Vision Transformer encoder.

    A ``ResNetV2`` backbone produces a feature map that ``HybridEmbed`` turns
    into patch tokens; a class token and position embeddings are added and
    the sequence is passed through ``depth`` transformer blocks.

    Args:
        backbone_layers: per-stage block counts for ``ResNetV2``; defaults to
            ``[2, 3, 7]``.
        input_channel: number of image channels.
        is_predict: stored on the instance.
        is_export: forwarded to ``ResNetV2`` and stored on the instance.
        img_size: ``(height, width)`` the position-embedding grid is laid out for.
        patch_size: overall down-sampling factor of one token.
        num_classes: size of the classifier head; ``0`` uses an identity head.
        embed_dim: token channel size.
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden size as a multiple of ``embed_dim``.
        qkv_bias: whether the q/k/v projection has a bias.
        representation_size: optional size of a ``pre_logits`` layer.
        distilled: adds a distillation token parameter and head.
        drop_rate: dropout rate.
        attn_drop_rate: attention dropout rate.
        drop_path_rate: maximum stochastic-depth rate (linearly increased
            over the blocks).
        embed_layer: accepted for signature compatibility; unused.
        norm_layer: normalization layer factory; defaults to ``LayerNorm``
            with ``epsilon=1e-6``.
        act_layer: activation layer class; defaults to ``nn.GELU``.
        weight_init: one of ``""``, ``"jax"``, ``"jax_nlhb"``, ``"nlhb"``.

    Forward input:
        In training mode a 3-item sequence ``(x, label, attention_mask)`` where
        ``x`` is the image batch ``[N, C, H, W]``. In eval mode either the
        image batch itself or a list/tuple whose first item is the image batch.

    Returns:
        Token features of shape ``[N, 1 + P, embed_dim]`` (``P`` patch tokens
        plus the class token) after ``self.head``. In training mode the tuple
        ``(features, label, attention_mask)`` is returned instead.
    """

    def __init__(
        self,
        backbone_layers=None,
        input_channel=1,
        is_predict=False,
        is_export=False,
        img_size=(224, 224),
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        representation_size=None,
        distilled=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        embed_layer=None,
        norm_layer=None,
        act_layer=None,
        weight_init="",
        **kwargs,
    ):
        super(HybridTransformer, self).__init__()
        # Avoid a shared mutable default argument.
        if backbone_layers is None:
            backbone_layers = [2, 3, 7]
        self.num_classes = num_classes
        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6)
        act_layer = act_layer or nn.GELU
        self.height, self.width = img_size
        self.patch_size = patch_size
        backbone = ResNetV2(
            layers=backbone_layers,
            num_classes=0,
            global_pool="",
            in_chans=input_channel,
            preact=False,
            stem_type="same",
            conv_layer=StdConv2dSame,
            is_export=is_export,
        )
        min_patch_size = 2 ** (len(backbone_layers) + 1)
        self.patch_embed = HybridEmbed(
            img_size=img_size,
            patch_size=patch_size // min_patch_size,
            in_chans=input_channel,
            embed_dim=embed_dim,
            backbone=backbone,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = paddle.create_parameter([1, 1, embed_dim], dtype="float32")
        self.dist_token = (
            paddle.create_parameter(
                [1, 1, embed_dim],
                dtype="float32",
            )
            if distilled
            else None
        )
        self.pos_embed = paddle.create_parameter(
            [1, num_patches + self.num_tokens, embed_dim], dtype="float32"
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        zeros_(self.cls_token)
        if self.dist_token is not None:
            zeros_(self.dist_token)
        zeros_(self.pos_embed)

        dpr = [
            x.item() for x in paddle.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.Sequential(
            *[
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(
                ("fc", nn.Linear(embed_dim, representation_size)), ("act", nn.Tanh())
            )
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = (
            nn.Linear(self.num_features, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.head_dist = None
        if distilled:
            self.head_dist = (
                nn.Linear(self.embed_dim, self.num_classes)
                if num_classes > 0
                else nn.Identity()
            )
        self.init_weights(weight_init)
        self.out_channels = embed_dim
        self.is_predict = is_predict
        self.is_export = is_export

    def init_weights(self, mode=""):
        """Initialize the embeddings and all sub-layers.

        ``mode`` is only validated; every sub-layer is initialized through
        ``_init_vit_weights`` with its default arguments.
        """
        assert mode in ("jax", "jax_nlhb", "nlhb", "")
        trunc_normal_(self.pos_embed)
        trunc_normal_(self.cls_token)
        self.apply(_init_vit_weights)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    def load_pretrained(self, checkpoint_path, prefix=""):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def no_weight_decay(self):
        """Return the names of parameters that should be excluded from weight decay."""
        return {"pos_embed", "cls_token", "dist_token"}

    def get_classifier(self):
        """Return the head, or ``(head, head_dist)`` for a distilled model."""
        if self.dist_token is None:
            return self.head
        else:
            return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=""):
        """Replace the classifier head(s) for ``num_classes`` outputs."""
        self.num_classes = num_classes
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )
        if self.num_tokens == 2:
            self.head_dist = (
                nn.Linear(self.embed_dim, self.num_classes)
                if num_classes > 0
                else nn.Identity()
            )

    def forward_features(self, x):
        """Encode an image batch ``[N, C, H, W]`` into normalized tokens."""
        B, c, h, w = x.shape
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(
            [B, -1, -1]
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = paddle.concat((cls_tokens, x), axis=1)
        # Token grid of the actual input.
        h, w = h // self.patch_size, w // self.patch_size
        # The position table is laid out for a grid that is
        # ``self.width // self.patch_size`` wide; each row of a narrower input
        # has to skip the unused columns of the full-width grid.
        repeat_tensor = (
            paddle.arange(h) * (self.width // self.patch_size - w)
        ).reshape([-1, 1])
        repeat_tensor = paddle.repeat_interleave(
            repeat_tensor, paddle.to_tensor(w), axis=1
        ).reshape([-1])
        pos_emb_ind = repeat_tensor + paddle.arange(h * w)
        # Index 0 is the class token; patch indices are shifted by one.
        pos_emb_ind = paddle.concat(
            (paddle.zeros([1], dtype="int64"), pos_emb_ind + 1), axis=0
        ).cast(paddle.int64)
        x += self.pos_embed[:, pos_emb_ind]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x

    def forward(self, input_data):
        """Encode the images; see the class docstring for input/output formats."""
        if self.training:
            x, label, attention_mask = input_data
        else:
            # Accept tuples as well as lists when the batch is wrapped.
            if isinstance(input_data, (list, tuple)):
                x = input_data[0]
            else:
                x = input_data
        x = self.forward_features(x)
        x = self.head(x)
        if self.training:
            return x, label, attention_mask
        else:
            return x
|
||||||
|
|
||||||
|
|
||||||
|
def _init_vit_weights(
    module: nn.Layer, name: str = "", head_bias: float = 0.0, jax_impl: bool = False
):
    """ViT weight initialization for a single layer.

    Args:
        module: the layer to initialize (in place).
        name: layer name; ``"head*"`` and ``"pre_logits*"`` linear layers get
            special treatment.
        head_bias: constant used for the bias of head layers.
        jax_impl: use Xavier-uniform weights (and a tiny normal bias for
            layers whose name contains ``"mlp"``) instead of truncated normal.

    Linear layers: head weights are zeroed and the bias set to ``head_bias``;
    pre_logits only gets a zero bias; everything else gets truncated-normal
    (or Xavier) weights and a zero bias. Conv layers are only touched when
    ``jax_impl`` is set (bias zeroed). Norm layers get zero bias and unit weight.
    """
    if isinstance(module, nn.Linear):
        if name.startswith("head"):
            zeros_(module.weight)
            if module.bias is not None:
                # An initializer is called with the parameter only; the fill
                # value belongs to the initializer's constructor.
                Constant(value=head_bias)(module.bias)
        elif name.startswith("pre_logits"):
            if module.bias is not None:
                zeros_(module.bias)
        else:
            if jax_impl:
                xavier_uniform_(module.weight)
                if module.bias is not None:
                    if "mlp" in name:
                        normal_(module.bias)
                    else:
                        zeros_(module.bias)
            else:
                trunc_normal_(module.weight)
                if module.bias is not None:
                    zeros_(module.bias)
    elif jax_impl and isinstance(module, nn.Conv2D):
        # NOTE conv weights are left at the framework default
        if module.bias is not None:
            zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2D)):
        zeros_(module.bias)
        ones_(module.weight)
|
||||||
1283
ppocr/modeling/backbones/rec_resnetv2.py
Normal file
1283
ppocr/modeling/backbones/rec_resnetv2.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -40,6 +40,7 @@ def build_head(config):
|
|||||||
from .rec_visionlan_head import VLHead
|
from .rec_visionlan_head import VLHead
|
||||||
from .rec_rfl_head import RFLHead
|
from .rec_rfl_head import RFLHead
|
||||||
from .rec_can_head import CANHead
|
from .rec_can_head import CANHead
|
||||||
|
from .rec_latexocr_head import LaTeXOCRHead
|
||||||
from .rec_satrn_head import SATRNHead
|
from .rec_satrn_head import SATRNHead
|
||||||
from .rec_parseq_head import ParseQHead
|
from .rec_parseq_head import ParseQHead
|
||||||
from .rec_cppd_head import CPPDHead
|
from .rec_cppd_head import CPPDHead
|
||||||
@ -81,6 +82,7 @@ def build_head(config):
|
|||||||
"RFLHead",
|
"RFLHead",
|
||||||
"DRRGHead",
|
"DRRGHead",
|
||||||
"CANHead",
|
"CANHead",
|
||||||
|
"LaTeXOCRHead",
|
||||||
"SATRNHead",
|
"SATRNHead",
|
||||||
"PFHeadLocal",
|
"PFHeadLocal",
|
||||||
"ParseQHead",
|
"ParseQHead",
|
||||||
|
|||||||
1027
ppocr/modeling/heads/rec_latexocr_head.py
Normal file
1027
ppocr/modeling/heads/rec_latexocr_head.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -42,6 +42,7 @@ from .rec_postprocess import (
|
|||||||
SATRNLabelDecode,
|
SATRNLabelDecode,
|
||||||
ParseQLabelDecode,
|
ParseQLabelDecode,
|
||||||
CPPDLabelDecode,
|
CPPDLabelDecode,
|
||||||
|
LaTeXOCRDecode,
|
||||||
)
|
)
|
||||||
from .cls_postprocess import ClsPostProcess
|
from .cls_postprocess import ClsPostProcess
|
||||||
from .pg_postprocess import PGPostProcess
|
from .pg_postprocess import PGPostProcess
|
||||||
@ -96,6 +97,7 @@ def build_post_process(config, global_config=None):
|
|||||||
"SATRNLabelDecode",
|
"SATRNLabelDecode",
|
||||||
"ParseQLabelDecode",
|
"ParseQLabelDecode",
|
||||||
"CPPDLabelDecode",
|
"CPPDLabelDecode",
|
||||||
|
"LaTeXOCRDecode",
|
||||||
]
|
]
|
||||||
|
|
||||||
if config["name"] == "PSEPostProcess":
|
if config["name"] == "PSEPostProcess":
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
from paddle.nn import functional as F
|
from paddle.nn import functional as F
|
||||||
|
from tokenizers import Tokenizer as TokenizerFast
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
@ -1210,3 +1211,53 @@ class CPPDLabelDecode(NRTRLabelDecode):
|
|||||||
def add_special_char(self, dict_character):
|
def add_special_char(self, dict_character):
|
||||||
dict_character = ["</s>"] + dict_character
|
dict_character = ["</s>"] + dict_character
|
||||||
return dict_character
|
return dict_character
|
||||||
|
|
||||||
|
|
||||||
|
class LaTeXOCRDecode(object):
    """Convert between latex-symbol and symbol-index.

    Args:
        rec_char_dict_path: path of the tokenizer JSON file loaded with
            ``TokenizerFast.from_file``.
    """

    def __init__(self, rec_char_dict_path, **kwargs):
        super(LaTeXOCRDecode, self).__init__()
        self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)

    def post_process(self, s):
        """Remove the superfluous spaces from a decoded LaTeX string.

        Spaces inside ``\\operatorname``/``\\mathrm``/``\\text``/``\\mathbf``
        groups are dropped first; then whitespace between two tokens is
        removed unless both neighbours are letters, repeating until the
        string no longer changes.

        Args:
            s: decoded LaTeX string.

        Returns:
            The cleaned-up string.
        """
        text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
        letter = "[a-zA-Z]"
        # Raw string: "\W" / "\d" are invalid escapes in a normal literal.
        noletter = r"[\W_^\d]"
        # Group 1 is the whole matched command, so it can be rewritten in place.
        s = re.sub(text_reg, lambda match: match.group(1).replace(" ", ""), s)

        between_symbols = r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter)
        symbol_then_letter = r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter)
        letter_then_symbol = r"(%s)\s+?(%s)" % (letter, noletter)
        while True:
            previous = s
            s = re.sub(between_symbols, r"\1\2", s)
            s = re.sub(symbol_then_letter, r"\1\2", s)
            s = re.sub(letter_then_symbol, r"\1\2", s)
            if s == previous:
                return s

    def decode(self, tokens):
        """Turn token-id arrays into cleaned LaTeX strings.

        Args:
            tokens: array of token ids, either one sequence (1-D) or a batch
                of sequences (2-D).

        Returns:
            A list with one post-processed string per sequence.
        """
        if len(tokens.shape) == 1:
            # Promote a single sequence to a batch of one.
            tokens = tokens[None, :]
        dec = [self.tokenizer.decode(tok) for tok in tokens]
        dec_str_list = [
            "".join(detok.split(" "))
            .replace("Ġ", " ")
            .replace("[EOS]", "")
            .replace("[BOS]", "")
            .replace("[PAD]", "")
            .strip()
            for detok in dec
        ]
        return [self.post_process(dec_str) for dec_str in dec_str_list]

    def __call__(self, preds, label=None, mode="eval", *args, **kwargs):
        """Decode predictions (and optionally labels).

        Args:
            preds: in ``"train"`` mode, scores whose argmax over axis 2 gives
                the token ids; otherwise the token ids themselves.
            label: optional token ids of the ground truth.
            mode: ``"train"`` or anything else for id input.

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when
            ``label`` is given.
        """
        if mode == "train":
            preds_idx = np.array(preds.argmax(axis=2))
            text = self.decode(preds_idx)
        else:
            text = self.decode(np.array(preds))
        if label is None:
            return text
        label = self.decode(np.array(label))
        return text, label
|
||||||
|
|||||||
1
ppocr/utils/dict/latex_ocr_tokenizer.json
Normal file
1
ppocr/utils/dict/latex_ocr_tokenizer.json
Normal file
File diff suppressed because one or more lines are too long
70
ppocr/utils/formula_utils/math_txt2pkl.py
Normal file
70
ppocr/utils/formula_utils/math_txt2pkl.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pickle
|
||||||
|
from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import imagesize
|
||||||
|
from collections import defaultdict
|
||||||
|
import glob
|
||||||
|
from os.path import join
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def txt2pickle(images, equations, save_dir):
    """Group (equation, image path) pairs by image size and pickle them.

    Every ``*.png`` in ``images`` is expected to be named by the (0-based)
    line number of its equation in ``equations``. Images whose size falls
    outside 32x32 .. 672x192 (width x height) are skipped.

    Args:
        images: directory containing the ``*.png`` images.
        equations: text file with one equation per line; ``None`` writes an
            empty mapping.
        save_dir: directory the ``latexocr_<images dir name>.pkl`` file is
            written to (created if missing).

    Raises:
        ValueError: if ``images`` is ``None``.
    """
    if images is None:
        raise ValueError("images must be a directory path, got None")
    # normpath drops a trailing separator so the dataset name is never empty.
    dataset_name = os.path.basename(os.path.normpath(images))
    save_p = os.path.join(save_dir, "latexocr_{}.pkl".format(dataset_name))
    min_dimensions = (32, 32)
    max_dimensions = (672, 192)
    data = defaultdict(list)
    if equations is not None:
        images_list = [
            path.replace("\\", "/") for path in glob.glob(join(images, "*.png"))
        ]
        indices = [int(os.path.basename(img).split(".")[0]) for img in images_list]
        with open(equations, "r", encoding="utf-8") as eq_file:
            eqs = eq_file.read().split("\n")
        for i, im in tqdm(enumerate(images_list), total=len(images_list)):
            width, height = imagesize.get(im)
            if (
                min_dimensions[0] <= width <= max_dimensions[0]
                and min_dimensions[1] <= height <= max_dimensions[1]
            ):
                data[(width, height)].append((eqs[indices[i]], im))
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(save_p, "wb") as file:
        # Dump a plain dict so the file does not depend on defaultdict.
        pickle.dump(dict(data), file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: convert an image directory plus an equation
    # text file into the pickled size-bucketed dataset.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--image_dir",
        type=str,
        default=".",
        help="Directory containing the formula images (*.png)",
    )
    parser.add_argument(
        "--mathtxt_path",
        type=str,
        default=".",
        help="Text file with one LaTeX equation per line",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="out_label.txt",
        help="Directory the generated .pkl file is written to",
    )

    args = parser.parse_args()
    txt2pickle(args.image_dir, args.mathtxt_path, args.output_dir)
|
||||||
@ -12,3 +12,6 @@ cython
|
|||||||
Pillow
|
Pillow
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
|
albumentations==1.4.10
|
||||||
|
tokenizers==0.19.1
|
||||||
|
imagesize
|
||||||
|
|||||||
@ -105,6 +105,8 @@ def main():
|
|||||||
if "model_type" in config["Architecture"].keys():
|
if "model_type" in config["Architecture"].keys():
|
||||||
if config["Architecture"]["algorithm"] == "CAN":
|
if config["Architecture"]["algorithm"] == "CAN":
|
||||||
model_type = "can"
|
model_type = "can"
|
||||||
|
elif config["Architecture"]["algorithm"] == "LaTeXOCR":
|
||||||
|
model_type = "latexocr"
|
||||||
else:
|
else:
|
||||||
model_type = config["Architecture"]["model_type"]
|
model_type = config["Architecture"]["model_type"]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -131,6 +131,11 @@ def export_single_model(
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
model = to_static(model, input_spec=other_shape)
|
model = to_static(model, input_spec=other_shape)
|
||||||
|
elif arch_config["algorithm"] == "LaTeXOCR":
|
||||||
|
other_shape = [
|
||||||
|
paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
|
||||||
|
]
|
||||||
|
model = to_static(model, input_spec=other_shape)
|
||||||
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
|
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
|
||||||
input_spec = [
|
input_spec = [
|
||||||
paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # input_ids
|
paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # input_ids
|
||||||
|
|||||||
@ -133,6 +133,11 @@ class TextRecognizer(object):
|
|||||||
"character_dict_path": args.rec_char_dict_path,
|
"character_dict_path": args.rec_char_dict_path,
|
||||||
"use_space_char": args.use_space_char,
|
"use_space_char": args.use_space_char,
|
||||||
}
|
}
|
||||||
|
elif self.rec_algorithm == "LaTeXOCR":
|
||||||
|
postprocess_params = {
|
||||||
|
"name": "LaTeXOCRDecode",
|
||||||
|
"rec_char_dict_path": args.rec_char_dict_path,
|
||||||
|
}
|
||||||
elif self.rec_algorithm == "ParseQ":
|
elif self.rec_algorithm == "ParseQ":
|
||||||
postprocess_params = {
|
postprocess_params = {
|
||||||
"name": "ParseQLabelDecode",
|
"name": "ParseQLabelDecode",
|
||||||
@ -450,6 +455,90 @@ class TextRecognizer(object):
|
|||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
def pad_(self, img, divable=32):
    """Crop a PIL image to its text and pad it to a multiple of ``divable``.

    The image is converted to a single channel (luminance, or inverted alpha
    when the alpha channel varies), stretched to the full 0..255 range and
    made dark-on-light. It is then cropped to the bounding box of the text
    pixels and pasted top-left onto a white canvas whose sides are rounded
    up to multiples of ``divable``.

    Args:
        img: PIL image.
        divable: the output width and height are multiples of this value.

    Returns:
        The padded image in mode ``"L"``. A uniform input (no text to find)
        yields a white canvas of the rounded-up input size.
    """
    threshold = 128
    data = np.array(img.convert("LA"))
    if data[..., -1].var() == 0:
        # Constant alpha carries no information: use the luminance channel.
        data = (data[..., 0]).astype(np.uint8)
    else:
        data = (255 - data[..., -1]).astype(np.uint8)
    if data.max() == data.min():
        # Uniform image: contrast stretching would divide by zero and there
        # is no text to crop, so return a blank canvas instead of crashing.
        blank_dims = [divable * -(-side // divable) for side in img.size]
        return Image.new("L", blank_dims, 255)
    data = (data - data.min()) / (data.max() - data.min()) * 255
    if data.mean() > threshold:
        # To invert the text to white
        gray = 255 * (data < threshold).astype(np.uint8)
    else:
        gray = 255 * (data > threshold).astype(np.uint8)
        data = 255 - data

    coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
    a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
    rect = data[b : b + h, a : a + w]
    im = Image.fromarray(rect).convert("L")
    dims = []
    for x in [w, h]:
        div, mod = divmod(x, divable)
        dims.append(divable * (div + (1 if mod > 0 else 0)))
    padded = Image.new("L", dims, 255)
    padded.paste(im, (0, 0, im.size[0], im.size[1]))
    return padded
|
||||||
|
|
||||||
|
def minmax_size_(
    self,
    img,
    max_dimensions,
    min_dimensions,
):
    """Fit a PIL image between a maximum and a minimum size.

    Args:
        img: PIL image.
        max_dimensions: ``(width, height)`` upper bound, or ``None``. A larger
            image is shrunk, keeping its aspect ratio.
        min_dimensions: ``(width, height)`` lower bound, or ``None``. A smaller
            image is pasted top-left onto a white canvas of the required size.

    Returns:
        The resized and/or padded image.
    """
    if max_dimensions is not None:
        ratios = [a / b for a, b in zip(img.size, max_dimensions)]
        if any([r > 1 for r in ratios]):
            size = np.array(img.size) // max(ratios)
            # Keep at least one pixel per side for extreme aspect ratios.
            size = np.maximum(size.astype(int), 1)
            img = img.resize(tuple(int(side) for side in size), Image.BILINEAR)
    if min_dimensions is not None:
        # hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions
        padded_size = [
            max(img_dim, min_dim)
            for img_dim, min_dim in zip(img.size, min_dimensions)
        ]
        if padded_size != list(img.size):  # assert hypothesis
            padded_im = Image.new("L", padded_size, 255)
            # Paste at the origin. getbbox() is the box of the non-zero
            # pixels, which can be smaller than the image and then makes
            # paste() fail with a size mismatch.
            padded_im.paste(img, (0, 0))
            img = padded_im
    return img
|
||||||
|
|
||||||
|
def norm_img_latexocr(self, img):
    """Prepare one image for the LaTeX-OCR model.

    An image outside the 32x32 .. 672x192 (width x height) range is cropped,
    padded and resized via ``pad_`` / ``minmax_size_`` and replicated to three
    channels. The image is then normalized, reduced to one channel, padded
    bottom/right to multiples of 16 and returned channel-first.

    Args:
        img: image array; the first two axes are height and width.

    Returns:
        ``float32`` array of shape ``[1, H', W']`` with ``H'`` and ``W'``
        multiples of 16.
    """
    channel_shape = (1, 1, 3)
    scale = 255.0
    min_dimensions = [32, 32]
    max_dimensions = [672, 192]
    mean = np.array([0.7931, 0.7931, 0.7931]).reshape(channel_shape).astype("float32")
    std = np.array([0.1738, 0.1738, 0.1738]).reshape(channel_shape).astype("float32")

    im_h, im_w = img.shape[:2]
    width_ok = min_dimensions[0] <= im_w <= max_dimensions[0]
    height_ok = min_dimensions[1] <= im_h <= max_dimensions[1]
    if not (width_ok and height_ok):
        pil_img = Image.fromarray(np.uint8(img))
        pil_img = self.minmax_size_(self.pad_(pil_img), max_dimensions, min_dimensions)
        img = np.array(pil_img)
        im_h, im_w = img.shape[:2]
        # The helpers return a single channel; restore three for the
        # channel-wise normalization below.
        img = np.dstack([img, img, img])
    # NOTE(review): mean/std look like values for a [0, 1] pixel range, yet the
    # pixels are multiplied by 255.0 here — confirm whether this should be a
    # division instead.
    img = (img.astype("float32") * scale - mean) / std
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    target_h = math.ceil(im_h / 16) * 16
    target_w = math.ceil(im_w / 16) * 16
    img = np.pad(
        img, ((0, target_h - im_h), (0, target_w - im_w)), constant_values=(1, 1)
    )
    # [H, W] -> [H, W, 1] -> [1, H, W]
    img = img[:, :, np.newaxis].transpose(2, 0, 1)
    return img.astype("float32")
|
||||||
|
|
||||||
def __call__(self, img_list):
|
def __call__(self, img_list):
|
||||||
img_num = len(img_list)
|
img_num = len(img_list)
|
||||||
# Calculate the aspect ratio of all text bars
|
# Calculate the aspect ratio of all text bars
|
||||||
@ -552,6 +641,10 @@ class TextRecognizer(object):
|
|||||||
word_label_list = []
|
word_label_list = []
|
||||||
norm_img_mask_batch.append(norm_image_mask)
|
norm_img_mask_batch.append(norm_image_mask)
|
||||||
word_label_list.append(word_label)
|
word_label_list.append(word_label)
|
||||||
|
elif self.rec_algorithm == "LaTeXOCR":
|
||||||
|
norm_img = self.norm_img_latexocr(img_list[indices[ino]])
|
||||||
|
norm_img = norm_img[np.newaxis, :]
|
||||||
|
norm_img_batch.append(norm_img)
|
||||||
else:
|
else:
|
||||||
norm_img = self.resize_norm_img(
|
norm_img = self.resize_norm_img(
|
||||||
img_list[indices[ino]], max_wh_ratio
|
img_list[indices[ino]], max_wh_ratio
|
||||||
@ -666,6 +759,29 @@ class TextRecognizer(object):
|
|||||||
if self.benchmark:
|
if self.benchmark:
|
||||||
self.autolog.times.stamp()
|
self.autolog.times.stamp()
|
||||||
preds = outputs
|
preds = outputs
|
||||||
|
elif self.rec_algorithm == "LaTeXOCR":
|
||||||
|
inputs = [norm_img_batch]
|
||||||
|
if self.use_onnx:
|
||||||
|
input_dict = {}
|
||||||
|
input_dict[self.input_tensor.name] = norm_img_batch
|
||||||
|
outputs = self.predictor.run(self.output_tensors, input_dict)
|
||||||
|
preds = outputs
|
||||||
|
else:
|
||||||
|
input_names = self.predictor.get_input_names()
|
||||||
|
input_tensor = []
|
||||||
|
for i in range(len(input_names)):
|
||||||
|
input_tensor_i = self.predictor.get_input_handle(input_names[i])
|
||||||
|
input_tensor_i.copy_from_cpu(inputs[i])
|
||||||
|
input_tensor.append(input_tensor_i)
|
||||||
|
self.input_tensor = input_tensor
|
||||||
|
self.predictor.run()
|
||||||
|
outputs = []
|
||||||
|
for output_tensor in self.output_tensors:
|
||||||
|
output = output_tensor.copy_to_cpu()
|
||||||
|
outputs.append(output)
|
||||||
|
if self.benchmark:
|
||||||
|
self.autolog.times.stamp()
|
||||||
|
preds = outputs
|
||||||
else:
|
else:
|
||||||
if self.use_onnx:
|
if self.use_onnx:
|
||||||
input_dict = {}
|
input_dict = {}
|
||||||
@ -692,6 +808,9 @@ class TextRecognizer(object):
|
|||||||
wh_ratio_list=wh_ratio_list,
|
wh_ratio_list=wh_ratio_list,
|
||||||
max_wh_ratio=max_wh_ratio,
|
max_wh_ratio=max_wh_ratio,
|
||||||
)
|
)
|
||||||
|
elif self.postprocess_params["name"] == "LaTeXOCRDecode":
|
||||||
|
preds = [p.reshape([-1]) for p in preds]
|
||||||
|
rec_result = self.postprocess_op(preds)
|
||||||
else:
|
else:
|
||||||
rec_result = self.postprocess_op(preds)
|
rec_result = self.postprocess_op(preds)
|
||||||
for rno in range(len(rec_result)):
|
for rno in range(len(rec_result)):
|
||||||
|
|||||||
@ -183,6 +183,8 @@ def main():
|
|||||||
elif isinstance(post_result, list) and isinstance(post_result[0], int):
|
elif isinstance(post_result, list) and isinstance(post_result[0], int):
|
||||||
# for RFLearning CNT branch
|
# for RFLearning CNT branch
|
||||||
info = str(post_result[0])
|
info = str(post_result[0])
|
||||||
|
elif config["Architecture"]["algorithm"] == "LaTeXOCR":
|
||||||
|
info = str(post_result[0])
|
||||||
else:
|
else:
|
||||||
if len(post_result[0]) >= 2:
|
if len(post_result[0]) >= 2:
|
||||||
info = post_result[0][0] + "\t" + str(post_result[0][1])
|
info = post_result[0][0] + "\t" + str(post_result[0][1])
|
||||||
|
|||||||
@ -324,6 +324,8 @@ def train(
|
|||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
elif algorithm in ["CAN"]:
|
elif algorithm in ["CAN"]:
|
||||||
preds = model(batch[:3])
|
preds = model(batch[:3])
|
||||||
|
elif algorithm in ["LaTeXOCR"]:
|
||||||
|
preds = model(batch)
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
preds = to_float32(preds)
|
preds = to_float32(preds)
|
||||||
@ -339,6 +341,8 @@ def train(
|
|||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
elif algorithm in ["CAN"]:
|
elif algorithm in ["CAN"]:
|
||||||
preds = model(batch[:3])
|
preds = model(batch[:3])
|
||||||
|
elif algorithm in ["LaTeXOCR"]:
|
||||||
|
preds = model(batch)
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
loss = loss_class(preds, batch)
|
loss = loss_class(preds, batch)
|
||||||
@ -360,6 +364,10 @@ def train(
|
|||||||
elif algorithm in ["CAN"]:
|
elif algorithm in ["CAN"]:
|
||||||
model_type = "can"
|
model_type = "can"
|
||||||
eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
|
eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
|
||||||
|
elif algorithm in ["LaTeXOCR"]:
|
||||||
|
model_type = "latexocr"
|
||||||
|
post_result = post_process_class(preds, batch[1], mode="train")
|
||||||
|
eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
|
||||||
else:
|
else:
|
||||||
if config["Loss"]["name"] in [
|
if config["Loss"]["name"] in [
|
||||||
"MultiLoss",
|
"MultiLoss",
|
||||||
@ -600,6 +608,8 @@ def eval(
|
|||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
elif model_type in ["can"]:
|
elif model_type in ["can"]:
|
||||||
preds = model(batch[:3])
|
preds = model(batch[:3])
|
||||||
|
elif model_type in ["latexocr"]:
|
||||||
|
preds = model(batch)
|
||||||
elif model_type in ["sr"]:
|
elif model_type in ["sr"]:
|
||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
sr_img = preds["sr_img"]
|
sr_img = preds["sr_img"]
|
||||||
@ -614,6 +624,8 @@ def eval(
|
|||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
elif model_type in ["can"]:
|
elif model_type in ["can"]:
|
||||||
preds = model(batch[:3])
|
preds = model(batch[:3])
|
||||||
|
elif model_type in ["latexocr"]:
|
||||||
|
preds = model(batch)
|
||||||
elif model_type in ["sr"]:
|
elif model_type in ["sr"]:
|
||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
sr_img = preds["sr_img"]
|
sr_img = preds["sr_img"]
|
||||||
@ -640,6 +652,9 @@ def eval(
|
|||||||
eval_class(preds, batch_numpy)
|
eval_class(preds, batch_numpy)
|
||||||
elif model_type in ["can"]:
|
elif model_type in ["can"]:
|
||||||
eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
|
eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
|
||||||
|
elif model_type in ["latexocr"]:
|
||||||
|
post_result = post_process_class(preds, batch[1], "eval")
|
||||||
|
eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
|
||||||
else:
|
else:
|
||||||
post_result = post_process_class(preds, batch_numpy[1])
|
post_result = post_process_class(preds, batch_numpy[1])
|
||||||
eval_class(post_result, batch_numpy)
|
eval_class(post_result, batch_numpy)
|
||||||
@ -777,6 +792,7 @@ def preprocess(is_train=False):
|
|||||||
"SVTR_HGNet",
|
"SVTR_HGNet",
|
||||||
"ParseQ",
|
"ParseQ",
|
||||||
"CPPD",
|
"CPPD",
|
||||||
|
"LaTeXOCR",
|
||||||
]
|
]
|
||||||
|
|
||||||
if use_xpu:
|
if use_xpu:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user