Mirror of https://github.com/PaddlePaddle/PaddleOCR.git
fix errors and add pretrain_model
This commit is contained in:
parent f698142542
commit 6824db26e7
@@ -31,7 +31,7 @@
 |- rgb/          training data of the total_text dataset
   |- gt_0.png
   | ...
-|-poly/          test annotations of the total_text dataset
+|- poly/         test annotations of the total_text dataset
   |- gt_0.txt
   | ...
 ```
@@ -52,19 +52,11 @@
 You can replace the backbone with a model from [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures) as needed.
 ```shell
 cd PaddleOCR/
-# Download the ResNet50_vd pretrained model
-wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
+# Download the ResNet50_vd dygraph pretrained model
+wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams
 
-# Extract the pretrained model file, taking ResNet50_vd as an example
-tar -xf ./pretrain_models/ResNet50_vd_ssld_pretrained.tar ./pretrain_models/
-
-# Note: after the backbone pretrained weights are extracted correctly, the folder contains many weight files named after network layers, for example:
-./pretrain_models/ResNet50_vd_ssld_pretrained/
-  └─ conv_last_bn_mean
-  └─ conv_last_bn_offset
-  └─ conv_last_bn_scale
-  └─ conv_last_bn_variance
-  └─ ......
+./pretrain_models/
+  └─ ResNet50_vd_ssld_pretrained.pdparams
 
 ```
 
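The dygraph checkpoint downloaded above is a single `.pdparams` file holding a parameter state dict, rather than a directory of per-layer weight files as in the old static-graph format. A minimal sketch for inspecting it, assuming PaddlePaddle is installed and the `wget` command above has been run:

```python
# Minimal sketch: inspect the downloaded dygraph pretrained weights.
import paddle

state_dict = paddle.load("./pretrain_models/ResNet50_vd_ssld_pretrained.pdparams")
print(len(state_dict), "parameter tensors")
# Peek at a few parameter names and shapes from the ResNet50_vd backbone.
for name, value in list(state_dict.items())[:5]:
    print(name, tuple(value.shape))
```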
@@ -74,11 +66,9 @@ tar -xf ./pretrain_models/ResNet50_vd_ssld_pretrained.tar ./pretrain_models/
 
 ```shell
 # Single-machine, single-GPU training of the e2e model
-python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml \
-    -o Global.pretrain_weights=./pretrain_models/ResNet50_vd_ssld_pretrained/ Global.load_static_weights=True
+python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/ResNet50_vd_ssld_pretrained Global.load_static_weights=False
 # Single-machine, multi-GPU training; set the GPU IDs to use via the --gpus option
-python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml \
-    -o Global.pretrain_weights=./pretrain_models/ResNet50_vd_ssld_pretrained/ Global.load_static_weights=True
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/ResNet50_vd_ssld_pretrained Global.load_static_weights=False
 ```
 
 
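The `-o` overrides merge `Section.key=value` pairs into the YAML config before training; `Global.pretrained_model` points at the dygraph checkpoint (given without its `.pdparams` extension), and `Global.load_static_weights=False` tells the trainer the weights are already in dygraph format. An illustrative sketch of how such overrides can be parsed, generic rather than the actual PaddleOCR option parser, assuming PyYAML is available:

```python
# Illustrative sketch (not PaddleOCR's own parser): merge "Section.key=value"
# overrides from -o into a nested config dict.
import yaml  # assumes PyYAML is installed

def apply_overrides(config, overrides):
    for item in overrides:
        key, value = item.split("=", 1)
        section, sub_key = key.split(".", 1)
        # yaml.safe_load turns "False" into a bool and leaves paths as strings.
        config.setdefault(section, {})[sub_key] = yaml.safe_load(value)
    return config

cfg = apply_overrides({}, [
    "Global.pretrained_model=./pretrain_models/ResNet50_vd_ssld_pretrained",
    "Global.load_static_weights=False",
])
print(cfg["Global"]["load_static_weights"])  # False (a real bool, not the string "False")
```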
@@ -369,9 +369,9 @@ Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
 <a name="PGNet端到端模型推理"></a>
 ### 1. PGNet end-to-end model inference
 #### (1). Quadrilateral text detection model (ICDAR2015)
-First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained on the ICDAR2015 English dataset with a ResNet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), the conversion can be done with the following command:
+First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained on the ICDAR2015 English dataset with a ResNet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), the conversion can be done with the following command:
 ```
-python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
+python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
 ```
 **For PGNet end-to-end model inference, the parameter `--e2e_algorithm="PGNet"` must be set**; the following command can be run:
 ```
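After `tools/export_model.py` finishes, `./inference/e2e` should contain the exported inference model, typically an `inference.pdmodel` graph definition plus `inference.pdiparams` weights. An optional sanity check, assuming the export command above completed successfully:

```python
# List what the export step wrote to the inference directory.
import os

export_dir = "./inference/e2e"
for name in sorted(os.listdir(export_dir)):
    size_mb = os.path.getsize(os.path.join(export_dir, name)) / 1e6
    print("{:<30} {:.1f} MB".format(name, size_mb))
```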
@@ -382,15 +382,10 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
 
 
 #### (2). Curved text detection model (Total-Text)
-First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained on the Total-Text English dataset with a ResNet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), the conversion can be done with the following command:
-
-```
-python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
-```
-
+It shares the same inference model as the quadrilateral text detection model.
 **For PGNet end-to-end model inference, the parameter `--e2e_algorithm="PGNet"` must be set, and the additional parameter `--e2e_pgnet_polygon=True` is also required**; the following command can be run:
 ```
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
+python3.7 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
 ```
 The visualized end-to-end text results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'e2e_res'. An example result:
 
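With `--e2e_pgnet_polygon=True` the detector returns free-form polygons rather than quadrilaterals, and the saved `e2e_res_*` images overlay those polygons together with the recognized text. A self-contained illustration of that kind of overlay, using dummy data on a blank canvas rather than PaddleOCR's own `draw_e2e_res`:

```python
# Illustrative only: draw one detected polygon and its recognized text,
# roughly what the saved e2e_res_* visualizations contain. All values are dummies.
import cv2
import numpy as np

canvas = np.full((200, 320, 3), 255, dtype=np.uint8)  # white placeholder image
poly = np.array([[40, 70], [260, 55], [268, 110], [46, 125]], dtype=np.int32)  # fake polygon
text = "PGNet"  # fake recognition result
cv2.polylines(canvas, [poly.reshape(-1, 1, 2)], True, (0, 0, 255), 2)
cv2.putText(canvas, text, (int(poly[0][0]), int(poly[0][1]) - 8),
            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
cv2.imwrite("e2e_res_demo.jpg", canvas)  # written to the current working directory
```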
@@ -27,7 +27,7 @@ class PGProcessTrain(object):
                  tcl_len,
                  batch_size=14,
                  min_crop_size=24,
-                 min_text_size=10,
+                 min_text_size=4,
                  max_text_size=512,
                  **kwargs):
         self.tcl_len = tcl_len
@@ -197,7 +197,6 @@ class PGProcessTrain(object):
             for selected_poly in selected_polys:
                 txts_tmp.append(txts[selected_poly])
             txts = txts_tmp
-            # print(1111)
             return im[ymin: ymax + 1, xmin: xmax + 1, :], \
                    polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
         else:
@@ -309,7 +308,6 @@ class PGProcessTrain(object):
             cv2.fillPoly(direction_map,
                          quad.round().astype(np.int32)[np.newaxis, :, :],
                          direction_label)
-            cv2.imwrite("output/{}.png".format(k), direction_map * 255.0)
             k += 1
         return direction_map
 
@@ -67,10 +67,7 @@ class PGDataSet(Dataset):
                     np.array(
                         list(poly), dtype=np.float32).reshape(-1, 2))
                 txts.append(txt)
-                if txt == '###':
-                    txt_tags.append(True)
-                else:
-                    txt_tags.append(False)
+                txt_tags.append(txt == '###')
 
         return np.array(list(map(np.array, text_polys))), \
                np.array(txt_tags, dtype=np.bool), txts
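The one-line replacement above appends exactly the boolean that the removed if/else produced: boxes whose transcription is `'###'`, the conventional marker for ignored or unreadable text, are tagged `True`. For illustration:

```python
# Same behavior as the simplified line, shown on a few sample transcriptions.
txts = ["PADDLE", "###", "OCR"]
txt_tags = [txt == '###' for txt in txts]
print(txt_tags)  # [False, True, False]
```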
@@ -84,8 +81,8 @@ class PGDataSet(Dataset):
             for ext in [
                     'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
             ]:
-                if os.path.exists(os.path.join(img_dir, info_list[0] + ext)):
-                    img_path = os.path.join(img_dir, info_list[0] + ext)
+                if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
+                    img_path = os.path.join(img_dir, info_list[0] + "." + ext)
                     break
 
             if img_path == '':
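The added `"." + ext` matters because the extension list holds bare suffixes; without the dot the candidate path can never match an existing file, so `img_path` stays empty. A small illustration with hypothetical file names:

```python
# Hypothetical paths, only to show why the "." separator is required.
import os

img_dir, stem, ext = "./train_data/total_text/train/rgb", "img_10", "jpg"
without_dot = os.path.join(img_dir, stem + ext)      # .../img_10jpg  -> never exists
with_dot = os.path.join(img_dir, stem + "." + ext)   # .../img_10.jpg -> matches the real image
print(without_dot, os.path.exists(without_dot))
print(with_dot, os.path.exists(with_dot))
```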
@@ -20,7 +20,7 @@ from paddle import nn
 import paddle
 
 from .det_basic_loss import DiceLoss
-from ppocr.utils.e2e_utils.extract_batchsize import *
+from ppocr.utils.e2e_utils.extract_batchsize import pre_process
 
 
 class PGLoss(nn.Layer):
@@ -18,8 +18,8 @@ from __future__ import print_function
 
 __all__ = ['E2EMetric']
 
-from ppocr.utils.e2e_metric.Deteval import *
-from ppocr.utils.e2e_utils.extract_textpoint import *
+from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
+from ppocr.utils.e2e_utils.extract_textpoint import get_dict
 
 
 class E2EMetric(object):
@@ -7,4 +7,5 @@ opencv-python==4.2.0.32
 tqdm
 numpy
 visualdl
 python-Levenshtein
+opencv-contrib-python
@@ -34,7 +34,7 @@ from ppocr.postprocess import build_post_process
 logger = get_logger()
 
 
-class TextE2e(object):
+class TextE2E(object):
     def __init__(self, args):
         self.args = args
         self.e2e_algorithm = args.e2e_algorithm
@@ -130,7 +130,7 @@ class TextE2e(object):
 if __name__ == "__main__":
     args = utility.parse_args()
     image_file_list = get_image_file_list(args.image_dir)
-    text_detector = TextE2e(args)
+    text_detector = TextE2E(args)
     count = 0
     total_time = 0
     draw_img_save = "./inference_results"
@@ -151,7 +151,7 @@ if __name__ == "__main__":
         src_im = utility.draw_e2e_res(points, strs, image_file)
         img_name_pure = os.path.split(image_file)[-1]
         img_path = os.path.join(draw_img_save,
-                                "e2e_res_{}".format(img_name_pure))
+                                "e2e_res_{}_pgnet".format(img_name_pure))
         cv2.imwrite(img_path, src_im)
         logger.info("The visualized image saved in {}".format(img_path))
         if count > 1:
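The updated pattern appends a `_pgnet` suffix after the original image name (note that it lands after the file extension). For example:

```python
# What the updated naming pattern produces for a sample input image name.
img_name_pure = "img623.jpg"
print("e2e_res_{}_pgnet".format(img_name_pure))  # -> e2e_res_img623.jpg_pgnet
```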