mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-11-02 18:59:20 +00:00
Merge remote-tracking branch 'origin/dygraph' into dygraph
This commit is contained in:
commit
e83d595502
@ -8,7 +8,7 @@ Global:
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
pretrained_model: ./pretrain_models/ch_ppocr_mobile_v2.1_det_distill_train/best_accuracy
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
@ -19,8 +19,22 @@ Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Teacher:
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
@ -37,7 +51,6 @@ Architecture:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Student2:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
@ -54,23 +67,7 @@ Architecture:
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Teacher:
|
||||
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
|
||||
132
configs/det/ch_ppocr_v2.1/ch_det_mv3_db_v2.1_student.yml
Normal file
132
configs/det/ch_ppocr_v2.1/ch_det_mv3_db_v2.1_student.yml
Normal file
@ -0,0 +1,132 @@
|
||||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [0, 400]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/student.pdparams
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DBPostProcess
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
||||
@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
||||
from ppocr.utils.utility import print_dict
|
||||
import tools.program as program
|
||||
|
||||
@ -60,7 +60,7 @@ def main():
|
||||
else:
|
||||
model_type = None
|
||||
|
||||
best_model_dict = init_model(config, model)
|
||||
best_model_dict = load_dygraph_params(config, model, logger, None)
|
||||
if len(best_model_dict):
|
||||
logger.info('metric in ckpt ***************')
|
||||
for k, v in best_model_dict.items():
|
||||
@ -71,7 +71,7 @@ def main():
|
||||
|
||||
# start eval
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, model_type, use_srn)
|
||||
eval_class, model_type, use_srn)
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metric.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
||||
@ -34,23 +34,21 @@ import paddle
|
||||
from ppocr.data import create_operators, transform
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
import tools.program as program
|
||||
|
||||
|
||||
def draw_det_res(dt_boxes, config, img, img_name):
|
||||
def draw_det_res(dt_boxes, config, img, img_name, save_path):
|
||||
if len(dt_boxes) > 0:
|
||||
import cv2
|
||||
src_im = img
|
||||
for box in dt_boxes:
|
||||
box = box.astype(np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
|
||||
save_det_path = os.path.dirname(config['Global'][
|
||||
'save_res_path']) + "/det_results/"
|
||||
if not os.path.exists(save_det_path):
|
||||
os.makedirs(save_det_path)
|
||||
save_path = os.path.join(save_det_path, os.path.basename(img_name))
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
save_path = os.path.join(save_path, os.path.basename(img_name))
|
||||
cv2.imwrite(save_path, src_im)
|
||||
logger.info("The detected Image saved in {}".format(save_path))
|
||||
|
||||
@ -61,8 +59,7 @@ def main():
|
||||
# build model
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
init_model(config, model)
|
||||
|
||||
_ = load_dygraph_params(config, model, logger, None)
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'])
|
||||
|
||||
@ -96,17 +93,41 @@ def main():
|
||||
images = paddle.to_tensor(images)
|
||||
preds = model(images)
|
||||
post_result = post_process_class(preds, shape_list)
|
||||
boxes = post_result[0]['points']
|
||||
# write result
|
||||
|
||||
src_img = cv2.imread(file)
|
||||
|
||||
dt_boxes_json = []
|
||||
for box in boxes:
|
||||
tmp_json = {"transcription": ""}
|
||||
tmp_json['points'] = box.tolist()
|
||||
dt_boxes_json.append(tmp_json)
|
||||
# parser boxes if post_result is dict
|
||||
if isinstance(post_result, dict):
|
||||
det_box_json = {}
|
||||
for k in post_result.keys():
|
||||
boxes = post_result[k][0]['points']
|
||||
dt_boxes_list = []
|
||||
for box in boxes:
|
||||
tmp_json = {"transcription": ""}
|
||||
tmp_json['points'] = box.tolist()
|
||||
dt_boxes_list.append(tmp_json)
|
||||
det_box_json[k] = dt_boxes_list
|
||||
save_det_path = os.path.dirname(config['Global'][
|
||||
'save_res_path']) + "/det_results_{}/".format(k)
|
||||
draw_det_res(boxes, config, src_img, file, save_det_path)
|
||||
else:
|
||||
boxes = post_result[0]['points']
|
||||
dt_boxes_json = []
|
||||
# write result
|
||||
for box in boxes:
|
||||
tmp_json = {"transcription": ""}
|
||||
tmp_json['points'] = box.tolist()
|
||||
dt_boxes_json.append(tmp_json)
|
||||
save_det_path = os.path.dirname(config['Global'][
|
||||
'save_res_path']) + "/det_results/"
|
||||
draw_det_res(boxes, config, src_img, file, save_det_path)
|
||||
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
|
||||
fout.write(otstr.encode())
|
||||
src_img = cv2.imread(file)
|
||||
draw_det_res(boxes, config, src_img, file)
|
||||
|
||||
save_det_path = os.path.dirname(config['Global'][
|
||||
'save_res_path']) + "/det_results/"
|
||||
draw_det_res(boxes, config, src_img, file, save_det_path)
|
||||
logger.info("success!")
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user