Mirror of https://github.com/PaddlePaddle/PaddleOCR.git (synced 2025-12-29 07:58:41 +00:00)
Submit SR model (#6933)
* add sr model
* update for eval
* submit sr
* polish code
* polish code
* polish code
* update sr model
* update doc
* update doc
* update doc
* fix typo
* format code
* update metric
* fix export
This commit is contained in:
parent f74f897f56, commit 7054013004
85  configs/sr/sr_tsrn_transformer_strock.yml  Normal file
@ -0,0 +1,85 @@
Global:
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/sr/sr_tsrn_transformer_strock/
  save_epoch_step: 3
  # evaluation is run every 1000 iterations
  eval_batch_step: [0, 1000]
  cal_metric_during_train: False
  pretrained_model:
  checkpoints:
  save_inference_dir: sr_output
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_52.png
  # for data or label process
  character_dict_path: ./train_data/srdata/english_decomposition.txt
  max_text_length: 100
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/sr/predicts_gestalt.txt

Optimizer:
  name: Adam
  beta1: 0.5
  beta2: 0.999
  clip_norm: 0.25
  lr:
    learning_rate: 0.0001

Architecture:
  model_type: sr
  algorithm: Gestalt
  Transform:
    name: TSRN
    STN: True
    infer_mode: False

Loss:
  name: StrokeFocusLoss
  character_dict_path: ./train_data/srdata/english_decomposition.txt

PostProcess:
  name: None

Metric:
  name: SRMetric
  main_indicator: all

Train:
  dataset:
    name: LMDBDataSetSR
    data_dir: ./train_data/srdata/train
    transforms:
      - SRResize:
          imgH: 32
          imgW: 128
          down_sample_scale: 2
      - SRLabelEncode: # Class handling label
      - KeepKeys:
          keep_keys: ['img_lr', 'img_hr', 'length', 'input_tensor', 'label'] # dataloader will return list in this order
  loader:
    shuffle: False
    batch_size_per_card: 16
    drop_last: True
    num_workers: 4

Eval:
  dataset:
    name: LMDBDataSetSR
    data_dir: ./train_data/srdata/test
    transforms:
      - SRResize:
          imgH: 32
          imgW: 128
          down_sample_scale: 2
      - SRLabelEncode: # Class handling label
      - KeepKeys:
          keep_keys: ['img_lr', 'img_hr', 'length', 'input_tensor', 'label'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 16
    num_workers: 4
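As a side note, a minimal sketch of reading this config in Python (assuming the `pyyaml` package is installed; the keys are the ones defined in the file above):

```python
import yaml

# Load the training config and read a few of the fields the trainer consumes.
with open("configs/sr/sr_tsrn_transformer_strock.yml") as f:
    cfg = yaml.safe_load(f)

start_iter, eval_every = cfg["Global"]["eval_batch_step"]  # [0, 1000]
print(cfg["Architecture"]["model_type"])                    # "sr"
print(cfg["Loss"]["name"])                                  # "StrokeFocusLoss"
```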
127  doc/doc_ch/algorithm_sr_gestalt.md  Normal file
@ -0,0 +1,127 @@
# Text Gestalt

- [1. Algorithm Introduction](#1)
- [2. Environment Setup](#2)
- [3. Model Training, Evaluation, and Prediction](#3)
  - [3.1 Training](#3-1)
  - [3.2 Evaluation](#3-2)
  - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving Deployment](#4-3)
  - [4.4 More Deployment Options](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
## 1. Algorithm Introduction

Paper:
> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf)

> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang

> AAAI, 2022

Following the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt) data download instructions, the super-resolution results on the TextZoom test set are as follows:

|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
|---|---|---|---|---|---|
|Text Gestalt|tsrn|19.28|0.6560|[configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[trained model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)|

<a name="2"></a>
## 2. Environment Setup
Please refer to [Environment Preparation](./environment.md) to set up the PaddleOCR environment, and to [Project Clone](./clone.md) to clone the project code.

<a name="3"></a>
## 3. Model Training, Evaluation, and Prediction

Please refer to the [text recognition training tutorial](./recognition.md). PaddleOCR modularizes the code, so training a different recognition model only requires **switching the configuration file**.

- Training

Once data preparation is complete, training can be started. The commands are as follows:

```
# Single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml

# Multi-GPU training; specify the GPU ids with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
```

- Evaluation

```
# GPU evaluation; Global.pretrained_model points to the weights to be evaluated
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```

- Prediction:

```
# The configuration file used for prediction must match the one used for training
python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```

![](../imgs_words_en/word_52.png)

After running the command, the super-resolution result of the image above is:

![](../imgs_results/sr_word_52.png)

<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference

First, convert the model saved during text super-resolution training into an inference model. Taking the [model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar) trained with Text-Gestalt as an example, it can be converted with the following command:
```shell
python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```
For Text-Gestalt super-resolution inference, run the following command:
```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```

After running the command, the super-resolution result of the image is:

![](../imgs_results/sr_word_52.png)

<a name="4-2"></a>
### 4.2 C++ Inference

Not supported yet

<a name="4-3"></a>
### 4.3 Serving Deployment

Not supported yet

<a name="4-4"></a>
### 4.4 More Deployment Options

Not supported yet

<a name="5"></a>
## 5. FAQ

## Citation

```bibtex
@inproceedings{chen2022text,
  title={Text gestalt: Stroke-aware scene text image super-resolution},
  author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={36},
  number={1},
  pages={285--293},
  year={2022}
}
```
136  doc/doc_en/algorithm_sr_gestalt_en.md  Normal file
@ -0,0 +1,136 @@
# Text Gestalt

- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
  - [3.1 Training](#3-1)
  - [3.2 Evaluation](#3-2)
  - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving](#4-3)
  - [4.4 More](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
## 1. Introduction

Paper:
> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf)

> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang

> AAAI, 2022

Following the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt) data download instructions, the results of the super-resolution algorithm on the TextZoom test set are as follows:

|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
|---|---|---|---|---|---|
|Text Gestalt|tsrn|19.28|0.6560|[configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[trained model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)|
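For reference, the PSNR_Avg and SSIM_Avg columns are the averages accumulated by `SRMetric` (`ppocr/metrics/sr_metric.py`). A minimal sketch of the PSNR it computes, assuming images normalized to [0, 1]:

```python
import paddle

def psnr(img1, img2):
    # mirrors SRMetric.calculate_psnr: images are rescaled to [0, 255]
    mse = ((img1 * 255 - img2 * 255) ** 2).mean()
    if mse == 0:
        return float('inf')
    return 20 * paddle.log10(255.0 / paddle.sqrt(mse))
```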
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.

<a name="3"></a>
## 3. Model Training / Evaluation / Prediction

Please refer to the [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, so training a different model only requires **changing the configuration file**.

Training:

After data preparation is complete, training can be started. The commands are as follows:

```
# Single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml

# Multi-GPU training; specify the GPU ids with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
```

Evaluation:

```
# GPU evaluation; Global.pretrained_model points to the weights to be evaluated
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```

Prediction:

```
# The configuration file used for prediction must match the one used for training
python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```

![](../imgs_words_en/word_52.png)

After executing the command, the super-resolution result of the above image is as follows:

![](../imgs_results/sr_word_52.png)

<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference

First, convert the model saved during training into an inference model ([model download link](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)). The following command can be used for the conversion:

```shell
python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```

For Text-Gestalt super-resolution model inference, the following command can be executed:

```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```

After executing the command, the super-resolution result of the above image is as follows:

![](../imgs_results/sr_word_52.png)

<a name="4-2"></a>
### 4.2 C++ Inference

Not supported

<a name="4-3"></a>
### 4.3 Serving

Not supported

<a name="4-4"></a>
### 4.4 More

Not supported

<a name="5"></a>
## 5. FAQ

## Citation

```bibtex
@inproceedings{chen2022text,
  title={Text gestalt: Stroke-aware scene text image super-resolution},
  author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={36},
  number={1},
  pages={285--293},
  year={2022}
}
```
@ -34,7 +34,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
-from ppocr.data.lmdb_dataset import LMDBDataSet
+from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
@ -54,7 +54,8 @@ def build_dataloader(config, mode, device, logger, seed=None):
    config = copy.deepcopy(config)

    support_dict = [
-        'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'
+        'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet',
+        'LMDBDataSetSR'
    ]
    module_name = config[mode]['dataset']['name']
    assert module_name in support_dict, Exception(
@ -1236,6 +1236,54 @@ class ABINetLabelEncode(BaseRecLabelEncode):
        return dict_character


class SRLabelEncode(BaseRecLabelEncode):
    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(SRLabelEncode, self).__init__(max_text_length,
                                            character_dict_path, use_space_char)
        # map each character to its stroke-decomposition sequence
        self.dic = {}
        with open(character_dict_path, 'r') as fin:
            for line in fin.readlines():
                line = line.strip()
                character, sequence = line.split()
                self.dic[character] = sequence
        english_stroke_alphabet = '0123456789'
        self.english_stroke_dict = {}
        for index in range(len(english_stroke_alphabet)):
            self.english_stroke_dict[english_stroke_alphabet[index]] = index

    def encode(self, label):
        stroke_sequence = ''
        for character in label:
            if character not in self.dic:
                continue
            else:
                stroke_sequence += self.dic[character]
        stroke_sequence += '0'
        label = stroke_sequence

        length = len(label)

        # shift the stroke ids right by one position (index 0 acts as start)
        input_tensor = np.zeros(self.max_text_len).astype("int64")
        for j in range(length - 1):
            input_tensor[j + 1] = self.english_stroke_dict[label[j]]

        return length, input_tensor

    def __call__(self, data):
        text = data['label']
        if text is None:
            return None
        length, input_tensor = self.encode(text)

        data["length"] = length
        data["input_tensor"] = input_tensor
        return data


class SPINLabelEncode(AttnLabelEncode):
    """ Convert between text-label and text-index """
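To illustrate the encoding above, a small sketch with a toy decomposition dictionary (the two-entry dict below is invented for the example; the real table ships in `english_decomposition.txt`):

```python
import numpy as np

# Toy stroke-decomposition table: each character maps to a digit sequence.
dic = {'a': '12', 'b': '345'}
max_text_len = 10

label = 'ab'
stroke_sequence = ''.join(dic[c] for c in label if c in dic) + '0'  # '123450'
length = len(stroke_sequence)                                        # 6

# input_tensor is shifted right by one, as in SRLabelEncode.encode.
input_tensor = np.zeros(max_text_len).astype("int64")
for j in range(length - 1):
    input_tensor[j + 1] = int(stroke_sequence[j])
print(input_tensor)  # [0 1 2 3 4 5 0 0 0 0]
```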
@ -24,6 +24,7 @@ import six
import cv2
import numpy as np
import math
+from PIL import Image


class DecodeImage(object):
@ -440,3 +441,52 @@ class KieResize(object):
        points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
        points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
        return points


class SRResize(object):
    def __init__(self,
                 imgH=32,
                 imgW=128,
                 down_sample_scale=4,
                 keep_ratio=False,
                 min_ratio=1,
                 mask=False,
                 infer_mode=False,
                 **kwargs):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio
        self.down_sample_scale = down_sample_scale
        self.mask = mask
        self.infer_mode = infer_mode

    def __call__(self, data):
        imgH = self.imgH
        imgW = self.imgW
        images_lr = data["image_lr"]
        # the LR branch is resized to 1/down_sample_scale of the target size
        transform2 = ResizeNormalize(
            (imgW // self.down_sample_scale, imgH // self.down_sample_scale))
        images_lr = transform2(images_lr)
        data["img_lr"] = images_lr
        if self.infer_mode:
            return data

        images_HR = data["image_hr"]
        label_strs = data["label"]
        transform = ResizeNormalize((imgW, imgH))
        images_HR = transform(images_HR)
        data["img_hr"] = images_HR
        return data


class ResizeNormalize(object):
    def __init__(self, size, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        img = img.resize(self.size, self.interpolation)
        img_numpy = np.array(img).astype("float32")
        img_numpy = img_numpy.transpose((2, 0, 1)) / 255
        return img_numpy
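A quick sketch of what `SRResize` produces for one sample, assuming the classes above are importable (the PIL image sizes are placeholders; the transform only cares about `imgH`/`imgW`):

```python
from PIL import Image

# Fake an LR/HR pair the way LMDBDataSetSR hands it to the transforms.
data = {
    "image_lr": Image.new("RGB", (64, 16)),
    "image_hr": Image.new("RGB", (128, 32)),
    "label": "text",
}
sr_resize = SRResize(imgH=32, imgW=128, down_sample_scale=2)
out = sr_resize(data)
print(out["img_lr"].shape)  # (3, 16, 64)  -> CHW float32, scaled to [0, 1]
print(out["img_hr"].shape)  # (3, 32, 128)
```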
@ -16,6 +16,9 @@ import os
from paddle.io import Dataset
import lmdb
import cv2
+import string
+import six
+from PIL import Image

from .imaug import transform, create_operators
@ -116,3 +119,58 @@ class LMDBDataSet(Dataset):
    def __len__(self):
        return self.data_idx_order_list.shape[0]


class LMDBDataSetSR(LMDBDataSet):
    def buf2PIL(self, txn, key, type='RGB'):
        imgbuf = txn.get(key)
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        im = Image.open(buf).convert(type)
        return im

    def str_filt(self, str_, voc_type):
        alpha_dict = {
            'digit': string.digits,
            'lower': string.digits + string.ascii_lowercase,
            'upper': string.digits + string.ascii_letters,
            'all': string.digits + string.ascii_letters + string.punctuation
        }
        if voc_type == 'lower':
            str_ = str_.lower()
        for char in str_:
            if char not in alpha_dict[voc_type]:
                str_ = str_.replace(char, '')
        return str_

    def get_lmdb_sample_info(self, txn, index):
        self.voc_type = 'upper'
        self.max_len = 100
        self.test = False
        label_key = b'label-%09d' % index
        word = str(txn.get(label_key).decode())
        img_HR_key = b'image_hr-%09d' % index  # 128*32
        img_lr_key = b'image_lr-%09d' % index  # 64*16
        try:
            img_HR = self.buf2PIL(txn, img_HR_key, 'RGB')
            img_lr = self.buf2PIL(txn, img_lr_key, 'RGB')
        except IOError:
            # broken record: fall back to the next sample
            return self[index + 1]
        if len(word) > self.max_len:
            return self[index + 1]
        label_str = self.str_filt(word, self.voc_type)
        return img_HR, img_lr, label_str

    def __getitem__(self, idx):
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
                                                file_idx)
        if sample_info is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        img_HR, img_lr, label_str = sample_info
        data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str}
        outs = transform(data, self.ops)
        if outs is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs
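For context, the LMDB records read above follow the TextZoom layout: per-index keys `label-%09d`, `image_hr-%09d`, and `image_lr-%09d`. A hypothetical writer sketch (the image file names are placeholders; the `num-samples` entry is an assumption based on the base `LMDBDataSet` indexing):

```python
import lmdb

env = lmdb.open("./train_data/srdata/train", map_size=1 << 30)
with env.begin(write=True) as txn:
    index = 1
    txn.put(b'label-%09d' % index, b'HELLO')
    txn.put(b'image_hr-%09d' % index, open('hr_128x32.png', 'rb').read())
    txn.put(b'image_lr-%09d' % index, open('lr_64x16.png', 'rb').read())
    txn.put(b'num-samples', b'1')
```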
@ -57,6 +57,9 @@ from .table_master_loss import TableMasterLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss

+# sr loss
+from .stroke_focus_loss import StrokeFocusLoss


def build_loss(config):
    support_dict = [
@ -64,7 +67,7 @@ def build_loss(config):
        'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
        'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
-        'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss'
+        'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss'
    ]
    config = copy.deepcopy(config)
    module_name = config.pop('name')
68  ppocr/losses/stroke_focus_loss.py  Normal file
@ -0,0 +1,68 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/stroke_focus_loss.py
"""
import cv2
import sys
import time
import string
import random
import numpy as np
import paddle.nn as nn
import paddle


class StrokeFocusLoss(nn.Layer):
    def __init__(self, character_dict_path=None, **kwargs):
        # nn.Layer.__init__ takes no dict path; the file is read below
        super(StrokeFocusLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.ce_loss = nn.CrossEntropyLoss()
        self.l1_loss = nn.L1Loss()
        self.english_stroke_alphabet = '0123456789'
        self.english_stroke_dict = {}
        for index in range(len(self.english_stroke_alphabet)):
            self.english_stroke_dict[self.english_stroke_alphabet[
                index]] = index

        with open(character_dict_path, 'r') as fin:
            stroke_decompose_lines = fin.readlines()
        self.dic = {}
        for line in stroke_decompose_lines:
            line = line.strip()
            character, sequence = line.split()
            self.dic[character] = sequence

    def forward(self, pred, data):
        sr_img = pred["sr_img"]
        hr_img = pred["hr_img"]

        mse_loss = self.mse_loss(sr_img, hr_img)
        word_attention_map_gt = pred["word_attention_map_gt"]
        word_attention_map_pred = pred["word_attention_map_pred"]

        hr_pred = pred["hr_pred"]
        sr_pred = pred["sr_pred"]

        attention_loss = paddle.nn.functional.l1_loss(word_attention_map_gt,
                                                      word_attention_map_pred)

        loss = (mse_loss + attention_loss * 50) * 100

        return {
            "mse_loss": mse_loss,
            "attention_loss": attention_loss,
            "loss": loss
        }
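A small sketch of the weighting used in `StrokeFocusLoss.forward`, driven with dummy tensors (shapes are illustrative; real batches come from the TSRN outputs):

```python
import paddle

sr_img = paddle.rand([2, 3, 32, 128])
hr_img = paddle.rand([2, 3, 32, 128])
attn_gt = paddle.rand([2, 16, 25, 256])    # attention maps; shape illustrative
attn_pred = paddle.rand([2, 16, 25, 256])

mse = paddle.nn.functional.mse_loss(sr_img, hr_img)
attn = paddle.nn.functional.l1_loss(attn_gt, attn_pred)
# same weighting as the loss above: attention term scaled 50x, total 100x
loss = (mse + attn * 50) * 100
```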
@ -30,13 +30,13 @@ from .table_metric import TableMetric
from .kie_metric import KIEMetric
from .vqa_token_ser_metric import VQASerTokenMetric
from .vqa_token_re_metric import VQAReTokenMetric

+from .sr_metric import SRMetric


def build_metric(config):
    support_dict = [
        "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
        "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
-        'VQAReTokenMetric'
+        'VQAReTokenMetric', 'SRMetric'
    ]

    config = copy.deepcopy(config)
@ -16,6 +16,7 @@ import Levenshtein
+import string


class RecMetric(object):
    def __init__(self,
                 main_indicator='acc',
155  ppocr/metrics/sr_metric.py  Normal file
@ -0,0 +1,155 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/utils/ssim_psnr.py
"""

from math import exp

import paddle
import paddle.nn.functional as F
import paddle.nn as nn
import string


class SSIM(nn.Layer):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = self.create_window(window_size, self.channel)

    def gaussian(self, window_size, sigma):
        gauss = paddle.to_tensor([
            exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
            for x in range(window_size)
        ])
        return gauss / gauss.sum()

    def create_window(self, window_size, channel):
        _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand([channel, 1, window_size, window_size])
        return window

    def _ssim(self, img1, img2, window, window_size, channel,
              size_average=True):
        mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
        mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(
            img1 * img1, window, padding=window_size // 2,
            groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(
            img2 * img2, window, padding=window_size // 2,
            groups=channel) - mu2_sq
        sigma12 = F.conv2d(
            img1 * img2, window, padding=window_size // 2,
            groups=channel) - mu1_mu2

        C1 = 0.01**2
        C2 = 0.03**2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
            (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        if size_average:
            return ssim_map.mean()
        else:
            return ssim_map.mean([1, 2, 3])

    def ssim(self, img1, img2, window_size=11, size_average=True):
        (_, channel, _, _) = img1.shape
        window = self.create_window(window_size, channel)

        return self._ssim(img1, img2, window, window_size, channel,
                          size_average)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.shape

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = self.create_window(self.window_size, channel)

            self.window = window
            self.channel = channel

        return self._ssim(img1, img2, window, self.window_size, channel,
                          self.size_average)


class SRMetric(object):
    def __init__(self, main_indicator='all', **kwargs):
        self.main_indicator = main_indicator
        self.eps = 1e-5
        self.psnr_result = []
        self.ssim_result = []
        self.calculate_ssim = SSIM()
        self.reset()

    def reset(self):
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0
        self.psnr_result = []
        self.ssim_result = []

    def calculate_psnr(self, img1, img2):
        # img1 and img2 have range [0, 1]
        mse = ((img1 * 255 - img2 * 255)**2).mean()
        if mse == 0:
            return float('inf')
        return 20 * paddle.log10(255.0 / paddle.sqrt(mse))

    def _normalize_text(self, text):
        text = ''.join(
            filter(lambda x: x in (string.digits + string.ascii_letters), text))
        return text.lower()

    def __call__(self, pred_label, *args, **kwargs):
        metric = {}
        images_sr = pred_label["sr_img"]
        images_hr = pred_label["hr_img"]
        psnr = self.calculate_psnr(images_sr, images_hr)
        ssim = self.calculate_ssim(images_sr, images_hr)
        self.psnr_result.append(psnr)
        self.ssim_result.append(ssim)

    def get_metric(self):
        """
        return metrics {
            'psnr_avg': 0,
            'ssim_avg': 0,
            'all': 0
        }
        """
        self.psnr_avg = sum(self.psnr_result) / len(self.psnr_result)
        self.psnr_avg = round(self.psnr_avg.item(), 6)
        self.ssim_avg = sum(self.ssim_result) / len(self.ssim_result)
        self.ssim_avg = round(self.ssim_avg.item(), 6)

        # the main indicator "all" is simply PSNR + SSIM
        self.all_avg = self.psnr_avg + self.ssim_avg

        self.reset()
        return {
            'psnr_avg': self.psnr_avg,
            "ssim_avg": self.ssim_avg,
            "all": self.all_avg
        }
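A sketch of driving `SRMetric` by hand, assuming the class above is importable (dummy random tensors; identical images would give infinite PSNR, which the averaging does not special-case):

```python
import paddle

metric = SRMetric(main_indicator='all')
pred = {
    "sr_img": paddle.rand([1, 3, 32, 128]),
    "hr_img": paddle.rand([1, 3, 32, 128]),
}
metric(pred)                # accumulates one PSNR/SSIM pair
print(metric.get_metric())  # {'psnr_avg': ..., 'ssim_avg': ..., 'all': ...}
```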
@ -14,6 +14,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
@ -46,9 +47,13 @@ class BaseModel(nn.Layer):
            in_channels = self.transform.out_channels

        # build backbone; a backbone is needed for det, rec and cls
-        config["Backbone"]['in_channels'] = in_channels
-        self.backbone = build_backbone(config["Backbone"], model_type)
-        in_channels = self.backbone.out_channels
+        if 'Backbone' not in config or config['Backbone'] is None:
+            self.use_backbone = False
+        else:
+            self.use_backbone = True
+            config["Backbone"]['in_channels'] = in_channels
+            self.backbone = build_backbone(config["Backbone"], model_type)
+            in_channels = self.backbone.out_channels

        # build neck
        # for rec, neck can be cnn,rnn or reshape(None)
@ -77,7 +82,8 @@ class BaseModel(nn.Layer):
        y = dict()
        if self.use_transform:
            x = self.transform(x)
-        x = self.backbone(x)
+        if self.use_backbone:
+            x = self.backbone(x)
        if isinstance(x, dict):
            y.update(x)
        else:
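To make the new branch concrete, a sketch of an `Architecture` dict that exercises it, mirroring the YAML config above (with no `Backbone` key, `use_backbone` stays False and the TSRN transform output flows straight through `forward`; the claim that missing `Neck`/`Head` keys are skipped the same way is an assumption based on the analogous guards in `BaseModel`):

```python
arch_config = {
    "model_type": "sr",
    "algorithm": "Gestalt",
    "in_channels": 3,
    "Transform": {"name": "TSRN", "STN": True, "infer_mode": False},
    # no "Backbone", "Neck" or "Head": those stages are skipped
}
```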
@ -109,4 +115,4 @@ class BaseModel(nn.Layer):
            else:
                return {final_name: x}
        else:
            return x
430  ppocr/modeling/heads/sr_rensnet_transformer.py  Normal file
@ -0,0 +1,430 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
"""
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import math, copy
import numpy as np

# stroke-level alphabet
alphabet = '0123456789'


def get_alphabet_len():
    return len(alphabet)


def subsequent_mask(size):
    """Generate a square mask for the sequence. The masked positions are
    filled with float('-inf'). Unmasked positions are filled with float(0.0).
    """
    mask = paddle.ones([1, size, size], dtype='float32')
    mask_inf = paddle.triu(
        paddle.full(
            shape=[1, size, size], dtype='float32', fill_value='-inf'),
        diagonal=1)
    mask = mask + mask_inf
    padding_mask = paddle.equal(mask, paddle.to_tensor(1, dtype=mask.dtype))
    return padding_mask


def clones(module, N):
    return nn.LayerList([copy.deepcopy(module) for _ in range(N)])


def masked_fill(x, mask, value):
    y = paddle.full(x.shape, value, x.dtype)
    return paddle.where(mask, y, x)


def attention(query, key, value, mask=None, dropout=None, attention_map=None):
    d_k = query.shape[-1]
    scores = paddle.matmul(query,
                           paddle.transpose(key, [0, 1, 3, 2])) / math.sqrt(d_k)

    if mask is not None:
        scores = masked_fill(scores, mask == 0, float('-inf'))

    p_attn = F.softmax(scores, axis=-1)

    if dropout is not None:
        p_attn = dropout(p_attn)
    return paddle.matmul(p_attn, value), p_attn


class MultiHeadedAttention(nn.Layer):
    def __init__(self, h, d_model, dropout=0.1, compress_attention=False):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
        self.compress_attention = compress_attention
        self.compress_attention_linear = nn.Linear(h, 1)

    def forward(self, query, key, value, mask=None, attention_map=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.shape[0]

        query, key, value = \
            [paddle.transpose(l(x).reshape([nbatches, -1, self.h, self.d_k]), [0, 2, 1, 3])
             for l, x in zip(self.linears, (query, key, value))]

        x, attention_map = attention(
            query,
            key,
            value,
            mask=mask,
            dropout=self.dropout,
            attention_map=attention_map)

        x = paddle.reshape(
            paddle.transpose(x, [0, 2, 1, 3]),
            [nbatches, -1, self.h * self.d_k])

        return self.linears[-1](x), attention_map


class ResNet(nn.Layer):
    def __init__(self, num_in, block, layers):
        super(ResNet, self).__init__()

        self.conv1 = nn.Conv2D(num_in, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2D(64, use_global_stats=True)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2D((2, 2), (2, 2))

        self.conv2 = nn.Conv2D(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2D(128, use_global_stats=True)
        self.relu2 = nn.ReLU()

        self.layer1_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer1 = self._make_layer(block, 128, 256, layers[0])
        self.layer1_conv = nn.Conv2D(256, 256, 3, 1, 1)
        self.layer1_bn = nn.BatchNorm2D(256, use_global_stats=True)
        self.layer1_relu = nn.ReLU()

        self.layer2_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer2 = self._make_layer(block, 256, 256, layers[1])
        self.layer2_conv = nn.Conv2D(256, 256, 3, 1, 1)
        self.layer2_bn = nn.BatchNorm2D(256, use_global_stats=True)
        self.layer2_relu = nn.ReLU()

        self.layer3_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer3 = self._make_layer(block, 256, 512, layers[2])
        self.layer3_conv = nn.Conv2D(512, 512, 3, 1, 1)
        self.layer3_bn = nn.BatchNorm2D(512, use_global_stats=True)
        self.layer3_relu = nn.ReLU()

        self.layer4_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer4 = self._make_layer(block, 512, 512, layers[3])
        self.layer4_conv2 = nn.Conv2D(512, 1024, 3, 1, 1)
        self.layer4_conv2_bn = nn.BatchNorm2D(1024, use_global_stats=True)
        self.layer4_conv2_relu = nn.ReLU()

    def _make_layer(self, block, inplanes, planes, blocks):
        if inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2D(inplanes, planes, 3, 1, 1),
                nn.BatchNorm2D(
                    planes, use_global_stats=True), )
        else:
            downsample = None
        layers = []
        layers.append(block(inplanes, planes, downsample))
        for i in range(1, blocks):
            layers.append(block(planes, planes, downsample=None))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.layer1_pool(x)
        x = self.layer1(x)
        x = self.layer1_conv(x)
        x = self.layer1_bn(x)
        x = self.layer1_relu(x)

        x = self.layer2(x)
        x = self.layer2_conv(x)
        x = self.layer2_bn(x)
        x = self.layer2_relu(x)

        x = self.layer3(x)
        x = self.layer3_conv(x)
        x = self.layer3_bn(x)
        x = self.layer3_relu(x)

        x = self.layer4(x)
        x = self.layer4_conv2(x)
        x = self.layer4_conv2_bn(x)
        x = self.layer4_conv2_relu(x)

        return x


class Bottleneck(nn.Layer):
    def __init__(self, input_dim):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2D(input_dim, input_dim, 1)
        self.bn1 = nn.BatchNorm2D(input_dim, use_global_stats=True)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2D(input_dim, input_dim, 3, 1, 1)
        self.bn2 = nn.BatchNorm2D(input_dim, use_global_stats=True)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu(out)

        return out


class PositionalEncoding(nn.Layer):
    "Implement the PE function."

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")

        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        pe = paddle.unsqueeze(pe, 0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :paddle.shape(x)[1]]
        return self.dropout(x)


class PositionwiseFeedForward(nn.Layer):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout, mode="downscale_in_infer")

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))


class Generator(nn.Layer):
    "Define standard linear + softmax generation step."

    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.proj(x)
        return out


class Embeddings(nn.Layer):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        embed = self.lut(x) * math.sqrt(self.d_model)
        return embed


class LayerNorm(nn.Layer):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = self.create_parameter(
            shape=[features],
            default_initializer=paddle.nn.initializer.Constant(1.0))
        self.b_2 = self.create_parameter(
            shape=[features],
            default_initializer=paddle.nn.initializer.Constant(0.0))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class Decoder(nn.Layer):
    def __init__(self):
        super(Decoder, self).__init__()

        self.mask_multihead = MultiHeadedAttention(
            h=16, d_model=1024, dropout=0.1)
        self.mul_layernorm1 = LayerNorm(1024)

        self.multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
        self.mul_layernorm2 = LayerNorm(1024)

        self.pff = PositionwiseFeedForward(1024, 2048)
        self.mul_layernorm3 = LayerNorm(1024)

    def forward(self, text, conv_feature, attention_map=None):
        text_max_length = text.shape[1]
        mask = subsequent_mask(text_max_length)
        result = text
        result = self.mul_layernorm1(result + self.mask_multihead(
            text, text, text, mask=mask)[0])
        b, c, h, w = conv_feature.shape
        conv_feature = paddle.transpose(
            conv_feature.reshape([b, c, h * w]), [0, 2, 1])
        word_image_align, attention_map = self.multihead(
            result,
            conv_feature,
            conv_feature,
            mask=None,
            attention_map=attention_map)
        result = self.mul_layernorm2(result + word_image_align)
        result = self.mul_layernorm3(result + self.pff(result))

        return result, attention_map


class BasicBlock(nn.Layer):
    def __init__(self, inplanes, planes, downsample):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2D(
            inplanes, planes, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2D(planes, use_global_stats=True)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(
            planes, planes, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2D(planes, use_global_stats=True)
        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(residual)

        out += residual
        out = self.relu(out)

        return out


class Encoder(nn.Layer):
    def __init__(self):
        super(Encoder, self).__init__()
        self.cnn = ResNet(num_in=1, block=BasicBlock, layers=[1, 2, 5, 3])

    def forward(self, input):
        conv_result = self.cnn(input)
        return conv_result


class Transformer(nn.Layer):
    def __init__(self, in_channels=1):
        super(Transformer, self).__init__()

        word_n_class = get_alphabet_len()
        self.embedding_word_with_upperword = Embeddings(512, word_n_class)
        self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)

        self.encoder = Encoder()
        self.decoder = Decoder()
        self.generator_word_with_upperword = Generator(1024, word_n_class)

        for p in self.parameters():
            if p.dim() > 1:
                # apply Xavier initialization to multi-dimensional parameters
                nn.initializer.XavierNormal()(p)

    def forward(self, image, text_length, text_input, attention_map=None):
        if image.shape[1] == 3:
            # convert RGB input to grayscale with the standard luma weights
            R = image[:, 0:1, :, :]
            G = image[:, 1:2, :, :]
            B = image[:, 2:3, :, :]
            image = 0.299 * R + 0.587 * G + 0.114 * B

        conv_feature = self.encoder(image)  # batch, 1024, 8, 32
        max_length = max(text_length)
        text_input = text_input[:, :max_length]

        text_embedding = self.embedding_word_with_upperword(
            text_input)  # batch, text_max_length, 512
        postion_embedding = self.pe(
            paddle.zeros(text_embedding.shape))  # batch, text_max_length, 512
        text_input_with_pe = paddle.concat([text_embedding, postion_embedding],
                                           2)  # batch, text_max_length, 1024
        batch, seq_len, _ = text_input_with_pe.shape

        text_input_with_pe, word_attention_map = self.decoder(
            text_input_with_pe, conv_feature)

        word_decoder_result = self.generator_word_with_upperword(
            text_input_with_pe)

        if self.training:
            total_length = paddle.sum(text_length)
            probs_res = paddle.zeros([total_length, get_alphabet_len()])
            start = 0

            for index, length in enumerate(text_length):
                length = int(length.numpy())
                probs_res[start:start + length, :] = word_decoder_result[
                    index, 0:0 + length, :]

                start = start + length

            return probs_res, word_attention_map, None
        else:
            return word_decoder_result
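A quick check of what `subsequent_mask` returns, assuming the function above is importable (position i may only attend to positions at or before i):

```python
import paddle

# For size=3: ones plus an upper-triangular -inf band, compared against 1,
# gives True on and below the diagonal and False above it.
print(subsequent_mask(3).astype('int32'))
# Tensor(shape=[1, 3, 3], ...)
# [[[1, 0, 0],
#   [1, 1, 0],
#   [1, 1, 1]]]
```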
@ -18,10 +18,10 @@ __all__ = ['build_transform']
def build_transform(config):
    from .tps import TPS
    from .stn import STN_ON
+    from .tsrn import TSRN
    from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN

-    support_dict = ['TPS', 'STN_ON', 'GA_SPIN']
+    support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN']

    module_name = config.pop('name')
    assert module_name in support_dict, Exception(
@ -153,4 +153,4 @@ class TPSSpatialTransformer(nn.Layer):
        # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
        grid = 2.0 * grid - 1.0
        output_maps = grid_sample(input, grid, canvas=None)
        return output_maps, source_coordinate
219  ppocr/modeling/transforms/tsrn.py  Normal file
@ -0,0 +1,219 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/model/tsrn.py
"""

import math
import paddle
import paddle.nn.functional as F
from paddle import nn
from collections import OrderedDict
import sys
import numpy as np
import warnings
import math, copy
import cv2

warnings.filterwarnings("ignore")

from .tps_spatial_transformer import TPSSpatialTransformer
from .stn import STN as STN_model
from ppocr.modeling.heads.sr_rensnet_transformer import Transformer


class TSRN(nn.Layer):
    def __init__(self,
                 in_channels,
                 scale_factor=2,
                 width=128,
                 height=32,
                 STN=False,
                 srb_nums=5,
                 mask=False,
                 hidden_units=32,
                 infer_mode=False,
                 **kwargs):
        super(TSRN, self).__init__()
        in_planes = 3
        if mask:
            in_planes = 4
        assert math.log(scale_factor, 2) % 1 == 0
        upsample_block_num = int(math.log(scale_factor, 2))
        self.block1 = nn.Sequential(
            nn.Conv2D(
                in_planes, 2 * hidden_units, kernel_size=9, padding=4),
            nn.PReLU())
        self.srb_nums = srb_nums
        for i in range(srb_nums):
            setattr(self, 'block%d' % (i + 2),
                    RecurrentResidualBlock(2 * hidden_units))

        setattr(
            self,
            'block%d' % (srb_nums + 2),
            nn.Sequential(
                nn.Conv2D(
                    2 * hidden_units,
                    2 * hidden_units,
                    kernel_size=3,
                    padding=1),
                nn.BatchNorm2D(2 * hidden_units)))

        block_ = [
            UpsampleBLock(2 * hidden_units, 2)
            for _ in range(upsample_block_num)
        ]
        block_.append(
            nn.Conv2D(
                2 * hidden_units, in_planes, kernel_size=9, padding=4))
        setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
        self.tps_inputsize = [height // scale_factor, width // scale_factor]
        tps_outputsize = [height // scale_factor, width // scale_factor]
        num_control_points = 20
        tps_margins = [0.05, 0.05]
        self.stn = STN
        if self.stn:
            self.tps = TPSSpatialTransformer(
                output_image_size=tuple(tps_outputsize),
                num_control_points=num_control_points,
                margins=tuple(tps_margins))

            self.stn_head = STN_model(
                in_channels=in_planes,
                num_ctrlpoints=num_control_points,
                activation='none')
        self.out_channels = in_channels

        # frozen stroke-level recognizer that supplies the attention maps
        self.r34_transformer = Transformer()
        for param in self.r34_transformer.parameters():
            param.trainable = False
        self.infer_mode = infer_mode

    def forward(self, x):
        output = {}
        if self.infer_mode:
            output["lr_img"] = x
            y = x
        else:
            output["lr_img"] = x[0]
            output["hr_img"] = x[1]
            y = x[0]
        if self.stn and self.training:
            _, ctrl_points_x = self.stn_head(y)
            y, _ = self.tps(y, ctrl_points_x)
        block = {'1': self.block1(y)}
        for i in range(self.srb_nums + 1):
            block[str(i + 2)] = getattr(self,
                                        'block%d' % (i + 2))(block[str(i + 1)])

        block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \
            ((block['1'] + block[str(self.srb_nums + 2)]))

        sr_img = paddle.tanh(block[str(self.srb_nums + 3)])

        output["sr_img"] = sr_img

        if self.training:
            hr_img = x[1]
            length = x[2]
            input_tensor = x[3]

            # run the frozen transformer on both SR and HR images to get
            # the attention maps consumed by StrokeFocusLoss
            sr_pred, word_attention_map_pred, _ = self.r34_transformer(
                sr_img, length, input_tensor)

            hr_pred, word_attention_map_gt, _ = self.r34_transformer(
                hr_img, length, input_tensor)

            output["hr_img"] = hr_img
            output["hr_pred"] = hr_pred
            output["word_attention_map_gt"] = word_attention_map_gt
            output["sr_pred"] = sr_pred
            output["word_attention_map_pred"] = word_attention_map_pred

        return output


class RecurrentResidualBlock(nn.Layer):
    def __init__(self, channels):
        super(RecurrentResidualBlock, self).__init__()
        self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2D(channels)
        self.gru1 = GruBlock(channels, channels)
        self.prelu = mish()
        self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2D(channels)
        self.gru2 = GruBlock(channels, channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        residual = self.gru1(residual.transpose([0, 1, 3, 2])).transpose(
            [0, 1, 3, 2])

        return self.gru2(x + residual)


class UpsampleBLock(nn.Layer):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2D(
            in_channels, in_channels * up_scale**2, kernel_size=3, padding=1)

        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = mish()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x


class mish(nn.Layer):
    def __init__(self, ):
        super(mish, self).__init__()
        self.activated = True

    def forward(self, x):
        if self.activated:
            x = x * (paddle.tanh(F.softplus(x)))
        return x


class GruBlock(nn.Layer):
    def __init__(self, in_channels, out_channels):
        super(GruBlock, self).__init__()
        assert out_channels % 2 == 0
        self.conv1 = nn.Conv2D(
            in_channels, out_channels, kernel_size=1, padding=0)
        self.gru = nn.GRU(out_channels,
                          out_channels // 2,
                          direction='bidirectional')

    def forward(self, x):
        # x: b, c, w, h
        x = self.conv1(x)
        x = x.transpose([0, 2, 3, 1])  # b, w, h, c
        batch_size, w, h, c = x.shape
        x = x.reshape([-1, h, c])  # b*w, h, c
        x, _ = self.gru(x)
        x = x.reshape([-1, w, h, c])
        x = x.transpose([0, 3, 1, 2])
        return x
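A sketch of an inference-mode forward pass through `TSRN`, assuming the class above is importable (in `infer_mode` the HR branch is skipped, so a single LR tensor goes in; the input shape matches the export `InputSpec` further below):

```python
import paddle

model = TSRN(in_channels=3, scale_factor=2, width=128, height=32,
             STN=True, infer_mode=True)
model.eval()
out = model(paddle.rand([1, 3, 16, 64]))  # LR input, half the target size
print(out["sr_img"].shape)                 # [1, 3, 32, 128]
```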
@ -148,10 +148,14 @@ def load_pretrained_params(model, path):
        "The {}.pdparams does not exist!".format(path)

    params = paddle.load(path + '.pdparams')

    state_dict = model.state_dict()

    new_state_dict = {}
    is_float16 = False

    for k1 in params.keys():

        if k1 not in state_dict.keys():
            logger.warning("The pretrained params {} not in model".format(k1))
        else:
@ -78,6 +78,12 @@ def export_single_model(model,
                paddle.static.InputSpec(
                    shape=[None, 3, 64, 512], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
+    elif arch_config["model_type"] == "sr":
+        other_shape = [
+            paddle.static.InputSpec(
+                shape=[None, 3, 16, 64], dtype="float32")
+        ]
+        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ViTSTR":
        other_shape = [
            paddle.static.InputSpec(
@ -195,6 +201,9 @@ def main():
    else:  # base rec model
        config["Architecture"]["Head"]["out_channels"] = char_num

+    # for sr algorithm
+    if config["Architecture"]["model_type"] == "sr":
+        config['Architecture']["Transform"]['infer_mode'] = True
    model = build_model(config["Architecture"])
    load_model(config, model, model_type=config['Architecture']["model_type"])
    model.eval()
155  tools/infer/predict_sr.py  Executable file
@ -0,0 +1,155 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from PIL import Image
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, __dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import numpy as np
import math
import time
import traceback
import paddle

import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif

logger = get_logger()


class TextSR(object):
    def __init__(self, args):
        self.sr_image_shape = [int(v) for v in args.sr_image_shape.split(",")]
        self.sr_batch_num = args.sr_batch_num

        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'sr', logger)
        self.benchmark = args.benchmark
        if args.benchmark:
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="sr",
                model_precision=args.precision,
                batch_size=args.sr_batch_num,
                data_shape="dynamic",
                save_path=None,  #args.save_log_path,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=0,
                logger=logger)

    def resize_norm_img(self, img):
        # The SR model upscales by 2x, so the input is resized to half of the
        # target SR shape before CHW transposition and [0, 1] normalization.
        imgC, imgH, imgW = self.sr_image_shape
        img = img.resize((imgW // 2, imgH // 2), Image.BICUBIC)
        img_numpy = np.array(img).astype("float32")
        img_numpy = img_numpy.transpose((2, 0, 1)) / 255
        return img_numpy

    def __call__(self, img_list):
        img_num = len(img_list)
        batch_num = self.sr_batch_num
        st = time.time()
        all_result = []
        if self.benchmark:
            self.autolog.times.start()
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            imgC, imgH, imgW = self.sr_image_shape
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[ino])
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)

            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            if self.benchmark:
                self.autolog.times.stamp()
            self.input_tensor.copy_from_cpu(norm_img_batch)
            self.predictor.run()
            outputs = []
            for output_tensor in self.output_tensors:
                output = output_tensor.copy_to_cpu()
                outputs.append(output)
            if len(outputs) != 1:
                preds = outputs
            else:
                preds = outputs[0]
            all_result.append(outputs)
            if self.benchmark:
                self.autolog.times.end(stamp=True)
        return all_result, time.time() - st


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextSR(args)
    valid_image_file_list = []
    img_list = []

    # warmup 2 times
    if args.warmup:
        # wrap the random array in a PIL image so resize_norm_img's
        # PIL-based resize works during warmup
        img = Image.fromarray(
            np.random.uniform(0, 255, [16, 64, 3]).astype(np.uint8))
        for i in range(2):
            res = text_recognizer([img] * int(args.sr_batch_num))

    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = Image.open(image_file).convert("RGB")
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        preds, _ = text_recognizer(img_list)
        for beg_no in range(len(preds)):
            sr_img = preds[beg_no][1]
            lr_img = preds[beg_no][0]
            for i in range(sr_img.shape[0]):
                fm_sr = (sr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
                fm_lr = (lr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
                img_name_pure = os.path.split(valid_image_file_list[
                    beg_no * args.sr_batch_num + i])[-1]
                cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
                            fm_sr[:, :, ::-1])
                logger.info("The visualized image saved in infer_result/sr_{}".
                            format(img_name_pure))

    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()
    if args.benchmark:
        text_recognizer.autolog.report()


if __name__ == "__main__":
    main(utility.parse_args())
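A hedged example of invoking the predictor above on an exported SR model; --sr_model_dir, --sr_image_shape and --sr_batch_num come from the init_args additions shown below, while the concrete paths are assumptions:

    python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out \
        --image_dir=doc/imgs_words_en/word_52.png \
        --sr_image_shape="3, 32, 128" --sr_batch_num=1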
@ -121,6 +121,11 @@ def init_args():
    parser.add_argument("--use_pdserving", type=str2bool, default=False)
    parser.add_argument("--warmup", type=str2bool, default=False)

    # SR params
    parser.add_argument("--sr_model_dir", type=str)
    parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
    parser.add_argument("--sr_batch_num", type=int, default=1)

    #
    parser.add_argument(
        "--draw_img_save_dir", type=str, default="./inference_results")
@ -156,6 +161,8 @@ def create_predictor(args, mode, logger):
        model_dir = args.table_model_dir
    elif mode == 'ser':
        model_dir = args.ser_model_dir
    elif mode == "sr":
        model_dir = args.sr_model_dir
    else:
        model_dir = args.e2e_model_dir
100
tools/infer_sr.py
Executable file
@ -0,0 +1,100 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import os
import sys
import json
from PIL import Image
import cv2

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, __dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import paddle

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program


def main():
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # sr transform
    config['Architecture']["Transform"]['infer_mode'] = True

    model = build_model(config['Architecture'])

    load_model(config, model)

    # create data ops
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name in ['SRResize']:
            op[op_name]['infer_mode'] = True
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['img_lr']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global'].get('save_res_path', "./infer_result")
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))
    # make sure the visualization directory exists before cv2.imwrite is called
    os.makedirs("infer_result", exist_ok=True)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        img = Image.open(file).convert("RGB")
        data = {'image_lr': img}
        batch = transform(data, ops)
        images = np.expand_dims(batch[0], axis=0)
        images = paddle.to_tensor(images)

        preds = model(images)
        sr_img = preds["sr_img"][0]
        lr_img = preds["lr_img"][0]
        fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
        fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
        img_name_pure = os.path.split(file)[-1]
        cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
                    fm_sr[:, :, ::-1])
        logger.info("The visualized image saved in infer_result/sr_{}".format(
            img_name_pure))

    logger.info("success!")


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    main()
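A hedged example of running this script; program.preprocess() handles the usual -c/-o options, Global.infer_img falls back to the value in the config, and the checkpoint path is an assumption:

    python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml \
        -o Global.pretrained_model=./output/sr/sr_tsrn_transformer_strock/best_accuracy \
           Global.infer_img=doc/imgs_words_en/word_52.png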
@ -25,6 +25,8 @@ import datetime
import paddle
import paddle.distributed as dist
from tqdm import tqdm
import cv2
import numpy as np
from argparse import ArgumentParser, RawDescriptionHelpFormatter

from ppocr.utils.stats import TrainingStats
@ -262,6 +264,7 @@ def train(config,
                config, 'Train', device, logger, seed=epoch)
            max_iter = len(train_dataloader) - 1 if platform.system(
            ) == "Windows" else len(train_dataloader)

        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
@ -289,7 +292,7 @@ def train(config,
            else:
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
-                elif model_type in ["kie", 'vqa']:
+                elif model_type in ["kie", 'vqa', 'sr']:
                    preds = model(batch)
                else:
                    preds = model(images)
@ -297,11 +300,12 @@ def train(config,
                avg_loss = loss['loss']
                avg_loss.backward()
                optimizer.step()

            optimizer.clear_grad()

            if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need
                batch = [item.numpy() for item in batch]
-                if model_type in ['kie']:
+                if model_type in ['kie', 'sr']:
                    eval_class(preds, batch)
                elif model_type in ['table']:
                    post_result = post_process_class(preds, batch)
@ -347,8 +351,8 @@ def train(config,
                    len(train_dataloader) - idx - 1) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                       '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                       'ips: {:.5f} samples/s, eta: {}'.format(
                           epoch, epoch_num, global_step, logs,
                           train_reader_cost / print_batch_step,
                           train_batch_cost / print_batch_step,
@ -480,12 +484,13 @@ def eval(model,
            leave=True)
        max_iter = len(valid_dataloader) - 1 if platform.system(
        ) == "Windows" else len(valid_dataloader)
        sum_images = 0
        for idx, batch in enumerate(valid_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            start = time.time()

            # use amp
            if scaler:
                with paddle.amp.auto_cast(level='O2'):
@ -493,6 +498,20 @@ def eval(model,
                    preds = model(images, data=batch[1:])
                elif model_type in ["kie", 'vqa']:
                    preds = model(batch)
                elif model_type in ['sr']:
                    preds = model(batch)
                    sr_img = preds["sr_img"]
                    lr_img = preds["lr_img"]

                    for i in range(sr_img.shape[0]):
                        fm_sr = (sr_img[i].numpy() * 255).transpose(
                            1, 2, 0).astype(np.uint8)
                        fm_lr = (lr_img[i].numpy() * 255).transpose(
                            1, 2, 0).astype(np.uint8)
                        cv2.imwrite("output/images/{}_{}_sr.jpg".format(
                            sum_images, i), fm_sr)
                        cv2.imwrite("output/images/{}_{}_lr.jpg".format(
                            sum_images, i), fm_lr)
                else:
                    preds = model(images)
            else:
@ -500,6 +519,20 @@ def eval(model,
                preds = model(images, data=batch[1:])
            elif model_type in ["kie", 'vqa']:
                preds = model(batch)
            elif model_type in ['sr']:
                preds = model(batch)
                sr_img = preds["sr_img"]
                lr_img = preds["lr_img"]

                for i in range(sr_img.shape[0]):
                    fm_sr = (sr_img[i].numpy() * 255).transpose(
                        1, 2, 0).astype(np.uint8)
                    fm_lr = (lr_img[i].numpy() * 255).transpose(
                        1, 2, 0).astype(np.uint8)
                    cv2.imwrite("output/images/{}_{}_sr.jpg".format(
                        sum_images, i), fm_sr)
                    cv2.imwrite("output/images/{}_{}_lr.jpg".format(
                        sum_images, i), fm_lr)
            else:
                preds = model(images)
@ -517,12 +550,15 @@ def eval(model,
            elif model_type in ['table', 'vqa']:
                post_result = post_process_class(preds, batch_numpy)
                eval_class(post_result, batch_numpy)
            elif model_type in ['sr']:
                eval_class(preds, batch_numpy)
            else:
                post_result = post_process_class(preds, batch_numpy[1])
                eval_class(post_result, batch_numpy)

            pbar.update(1)
            total_frame += len(images)
            sum_images += 1
        # Get final metric, eg. acc or hmean
        metric = eval_class.get_metric()
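For 'sr' models the raw predictions (paired LR/SR images) are fed straight to eval_class, i.e. the SRMetric configured with main_indicator 'all'. As a minimal sketch of the kind of image-fidelity score such a metric aggregates (illustrative PSNR only, not SRMetric's actual implementation):

    import numpy as np

    def psnr(sr, hr, max_val=1.0):
        # peak signal-to-noise ratio between two images scaled to [0, max_val]
        mse = np.mean((sr.astype(np.float64) - hr.astype(np.float64)) ** 2)
        if mse == 0:
            return float("inf")
        return 10 * np.log10(max_val ** 2 / mse)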
@ -616,7 +652,8 @@ def preprocess(is_train=False):
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
-        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN'
+        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
+        'Gestalt'
    ]

    if use_xpu:
@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer):
        config['Loss']['ignore_index'] = char_num - 1

    model = build_model(config['Architecture'])

    model = apply_to_static(model, config, logger)

    # build loss
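With these hooks in place, SR training goes through the standard entrypoint; a hedged single-card example using this commit's config:

    python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml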