mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-26 21:24:27 +00:00
add ocr-det v5 model (#15123)
* add ocr detV5 model * add ocr detV5 pretrained model link
This commit is contained in:
parent
0caa3e98de
commit
a836921984
174
configs/det/PP-OCRv5/PP-OCRv5_mobile_det.yml
Normal file
174
configs/det/PP-OCRv5/PP-OCRv5_mobile_det.yml
Normal file
@ -0,0 +1,174 @@
|
||||
Global:
|
||||
model_name: PP-OCRv5_mobile_det # To use static model for inference.
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: &epoch_num 500
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 100
|
||||
save_model_dir: ./output/PP-OCRv5_mobile_det
|
||||
save_epoch_step: 10
|
||||
eval_batch_step:
|
||||
- 0
|
||||
- 1500
|
||||
cal_metric_during_train: false
|
||||
checkpoints:
|
||||
pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/PPLCNetV3_x0_75_ocr_det.pdparams
|
||||
save_inference_dir: null
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
||||
d2s_train_image_shape: [3, 640, 640]
|
||||
distributed: true
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform: null
|
||||
Backbone:
|
||||
name: PPLCNetV3
|
||||
scale: 0.75
|
||||
det: True
|
||||
Neck:
|
||||
name: RSEFPN
|
||||
out_channels: 96
|
||||
shortcut: True
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
fix_nan: True
|
||||
|
||||
Loss:
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001 #(8*8c)
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: L2
|
||||
factor: 5.0e-05
|
||||
|
||||
PostProcess:
|
||||
name: DBPostProcess
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- DetLabelEncode: null
|
||||
- CopyPaste: null
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- type: Fliplr
|
||||
args:
|
||||
p: 0.5
|
||||
- type: Affine
|
||||
args:
|
||||
rotate:
|
||||
- -10
|
||||
- 10
|
||||
- type: Resize
|
||||
args:
|
||||
size:
|
||||
- 0.5
|
||||
- 3
|
||||
- EastRandomCropData:
|
||||
size:
|
||||
- 640
|
||||
- 640
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
total_epoch: *epoch_num
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
total_epoch: *epoch_num
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
order: hwc
|
||||
- ToCHWImage: null
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- threshold_map
|
||||
- threshold_mask
|
||||
- shrink_map
|
||||
- shrink_mask
|
||||
loader:
|
||||
shuffle: true
|
||||
drop_last: false
|
||||
batch_size_per_card: 8
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- DetLabelEncode: null
|
||||
- DetResizeForTest:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
order: hwc
|
||||
- ToCHWImage: null
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- shape
|
||||
- polys
|
||||
- ignore_tags
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 1
|
||||
num_workers: 2
|
||||
profiler_options: null
|
174
configs/det/PP-OCRv5/PP-OCRv5_server_det.yml
Normal file
174
configs/det/PP-OCRv5/PP-OCRv5_server_det.yml
Normal file
@ -0,0 +1,174 @@
|
||||
Global:
|
||||
model_name: PP-OCRv5_server_det # To use static model for inference.
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: &epoch_num 500
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/PP-OCRv5_server_det
|
||||
save_epoch_step: 10
|
||||
eval_batch_step:
|
||||
- 0
|
||||
- 1500
|
||||
cal_metric_during_train: false
|
||||
checkpoints:
|
||||
pretrained_model: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PPHGNetV2_B4_ocr_det.pdparams
|
||||
save_inference_dir: null
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
||||
distributed: true
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform: null
|
||||
Backbone:
|
||||
name: PPHGNetV2_B4
|
||||
det: True
|
||||
Neck:
|
||||
name: LKPAN
|
||||
out_channels: 256
|
||||
intracl: true
|
||||
Head:
|
||||
name: PFHeadLocal
|
||||
k: 50
|
||||
mode: "large"
|
||||
|
||||
|
||||
Loss:
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001 #(8*8c)
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: L2
|
||||
factor: 1e-6
|
||||
|
||||
PostProcess:
|
||||
name: DBPostProcess
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- DetLabelEncode: null
|
||||
- CopyPaste: null
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- type: Fliplr
|
||||
args:
|
||||
p: 0.5
|
||||
- type: Affine
|
||||
args:
|
||||
rotate:
|
||||
- -10
|
||||
- 10
|
||||
- type: Resize
|
||||
args:
|
||||
size:
|
||||
- 0.5
|
||||
- 3
|
||||
- EastRandomCropData:
|
||||
size:
|
||||
- 640
|
||||
- 640
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
total_epoch: *epoch_num
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
total_epoch: *epoch_num
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
order: hwc
|
||||
- ToCHWImage: null
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- threshold_map
|
||||
- threshold_mask
|
||||
- shrink_map
|
||||
- shrink_mask
|
||||
loader:
|
||||
shuffle: true
|
||||
drop_last: false
|
||||
batch_size_per_card: 8
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- DetLabelEncode: null
|
||||
- DetResizeForTest:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
order: hwc
|
||||
- ToCHWImage: null
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- shape
|
||||
- polys
|
||||
- ignore_tags
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 1
|
||||
num_workers: 2
|
||||
profiler_options: null
|
@ -28,6 +28,7 @@ def build_backbone(config, model_type):
|
||||
from .det_pp_lcnet_v2 import PPLCNetV2_base
|
||||
from .rec_repvit import RepSVTR_det
|
||||
from .rec_vary_vit import Vary_VIT_B
|
||||
from .rec_pphgnetv2 import PPHGNetV2_B4
|
||||
|
||||
support_dict = [
|
||||
"MobileNetV3",
|
||||
@ -40,6 +41,7 @@ def build_backbone(config, model_type):
|
||||
"PPLCNetV2_base",
|
||||
"RepSVTR_det",
|
||||
"Vary_VIT_B",
|
||||
"PPHGNetV2_B4",
|
||||
]
|
||||
if model_type == "table":
|
||||
from .table_master_resnet import TableResNetExtra
|
||||
|
@ -1381,9 +1381,11 @@ class PPHGNetV2(TheseusLayer):
|
||||
self.dropout = nn.Dropout(p=dropout_prob, mode="downscale_in_infer")
|
||||
|
||||
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
|
||||
self.fc = nn.Linear(
|
||||
self.class_expand if self.use_last_conv else out_channels, self.class_num
|
||||
)
|
||||
if not self.det:
|
||||
self.fc = nn.Linear(
|
||||
self.class_expand if self.use_last_conv else out_channels,
|
||||
self.class_num,
|
||||
)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user