mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-11-10 14:53:55 +00:00
add d2s train for slanet and v3 (#9341)
* add d2s train for slanet and v3 * fix bug
This commit is contained in:
parent
623424fce0
commit
2e05d54af8
@ -17,6 +17,7 @@ Global:
|
|||||||
infer_img: doc/imgs_en/img_10.jpg
|
infer_img: doc/imgs_en/img_10.jpg
|
||||||
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
||||||
distributed: true
|
distributed: true
|
||||||
|
d2s_train_image_shape: [3, -1, -1]
|
||||||
|
|
||||||
Architecture:
|
Architecture:
|
||||||
name: DistillationModel
|
name: DistillationModel
|
||||||
|
|||||||
@ -12,6 +12,7 @@ Global:
|
|||||||
use_visualdl: False
|
use_visualdl: False
|
||||||
seed: 2022
|
seed: 2022
|
||||||
infer_img: ppstructure/docs/kie/input/zh_val_42.jpg
|
infer_img: ppstructure/docs/kie/input/zh_val_42.jpg
|
||||||
|
d2s_train_image_shape: [3, 224, 224]
|
||||||
# if you want to predict using the groundtruth ocr info,
|
# if you want to predict using the groundtruth ocr info,
|
||||||
# you can use the following config
|
# you can use the following config
|
||||||
# infer_img: train_data/XFUND/zh_val/val.json
|
# infer_img: train_data/XFUND/zh_val/val.json
|
||||||
|
|||||||
@ -19,6 +19,7 @@ Global:
|
|||||||
use_space_char: true
|
use_space_char: true
|
||||||
distributed: true
|
distributed: true
|
||||||
save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
|
save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
|
||||||
|
d2s_train_image_shape: [3, 48, -1]
|
||||||
|
|
||||||
|
|
||||||
Optimizer:
|
Optimizer:
|
||||||
|
|||||||
@ -21,6 +21,7 @@ Global:
|
|||||||
infer_mode: False
|
infer_mode: False
|
||||||
use_sync_bn: True
|
use_sync_bn: True
|
||||||
save_res_path: 'output/infer'
|
save_res_path: 'output/infer'
|
||||||
|
d2s_train_image_shape: [3, -1, -1]
|
||||||
|
|
||||||
Optimizer:
|
Optimizer:
|
||||||
name: Adam
|
name: Adam
|
||||||
|
|||||||
@ -17,6 +17,7 @@ Global:
|
|||||||
infer_mode: false
|
infer_mode: false
|
||||||
max_text_length: &max_text_length 500
|
max_text_length: &max_text_length 500
|
||||||
box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy'
|
box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy'
|
||||||
|
d2s_train_image_shape: [3, 480, 480]
|
||||||
|
|
||||||
|
|
||||||
Optimizer:
|
Optimizer:
|
||||||
|
|||||||
@ -38,9 +38,9 @@ def build_model(config):
|
|||||||
def apply_to_static(model, config, logger):
|
def apply_to_static(model, config, logger):
|
||||||
if config["Global"].get("to_static", False) is not True:
|
if config["Global"].get("to_static", False) is not True:
|
||||||
return model
|
return model
|
||||||
assert "image_shape" in config[
|
assert "d2s_train_image_shape" in config[
|
||||||
"Global"], "image_shape must be assigned for static training mode..."
|
"Global"], "d2s_train_image_shape must be assigned for static training mode..."
|
||||||
supported_list = ["DB", "SVTR_LCNet", "TableMaster"]
|
supported_list = ["DB", "SVTR_LCNet", "TableMaster", "LayoutXLM", "SLANet"]
|
||||||
if config["Architecture"]["algorithm"] in ["Distillation"]:
|
if config["Architecture"]["algorithm"] in ["Distillation"]:
|
||||||
algo = list(config["Architecture"]["Models"].values())[0]["algorithm"]
|
algo = list(config["Architecture"]["Models"].values())[0]["algorithm"]
|
||||||
else:
|
else:
|
||||||
@ -49,7 +49,7 @@ def apply_to_static(model, config, logger):
|
|||||||
|
|
||||||
specs = [
|
specs = [
|
||||||
InputSpec(
|
InputSpec(
|
||||||
[None] + config["Global"]["image_shape"], dtype='float32')
|
[None] + config["Global"]["d2s_train_image_shape"], dtype='float32')
|
||||||
]
|
]
|
||||||
|
|
||||||
if algo == "SVTR_LCNet":
|
if algo == "SVTR_LCNet":
|
||||||
@ -62,7 +62,7 @@ def apply_to_static(model, config, logger):
|
|||||||
[None], dtype='int64'), InputSpec(
|
[None], dtype='int64'), InputSpec(
|
||||||
[None], dtype='float64')
|
[None], dtype='float64')
|
||||||
])
|
])
|
||||||
if algo == "TableMaster":
|
elif algo == "TableMaster":
|
||||||
specs.append(
|
specs.append(
|
||||||
[
|
[
|
||||||
InputSpec(
|
InputSpec(
|
||||||
@ -76,6 +76,34 @@ def apply_to_static(model, config, logger):
|
|||||||
InputSpec(
|
InputSpec(
|
||||||
[None, 6], dtype='float32'),
|
[None, 6], dtype='float32'),
|
||||||
])
|
])
|
||||||
|
elif algo == "LayoutXLM":
|
||||||
|
specs = [[
|
||||||
|
InputSpec(
|
||||||
|
shape=[None, 512], dtype="int64"), # input_ids
|
||||||
|
InputSpec(
|
||||||
|
shape=[None, 512, 4], dtype="int64"), # bbox
|
||||||
|
InputSpec(
|
||||||
|
shape=[None, 512], dtype="int64"), # attention_mask
|
||||||
|
InputSpec(
|
||||||
|
shape=[None, 512], dtype="int64"), # token_type_ids
|
||||||
|
InputSpec(
|
||||||
|
shape=[None, 3, 224, 224], dtype="float32"), # image
|
||||||
|
InputSpec(
|
||||||
|
shape=[None, 512], dtype="int64"), # label
|
||||||
|
]]
|
||||||
|
elif algo == "SLANet":
|
||||||
|
specs.append([
|
||||||
|
InputSpec(
|
||||||
|
[None, config["Global"]["max_text_length"] + 2], dtype='int64'),
|
||||||
|
InputSpec(
|
||||||
|
[None, config["Global"]["max_text_length"] + 2, 4],
|
||||||
|
dtype='float32'),
|
||||||
|
InputSpec(
|
||||||
|
[None, config["Global"]["max_text_length"] + 2, 1],
|
||||||
|
dtype='float32'),
|
||||||
|
InputSpec(
|
||||||
|
[None, 6], dtype='float64'),
|
||||||
|
])
|
||||||
model = to_static(model, input_spec=specs)
|
model = to_static(model, input_spec=specs)
|
||||||
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
|
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
|
||||||
return model
|
return model
|
||||||
|
|||||||
@ -20,6 +20,8 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
|
|
||||||
|
MODELS_DIR = os.path.expanduser("~/.paddleocr/models/")
|
||||||
|
|
||||||
|
|
||||||
def download_with_progressbar(url, save_path):
|
def download_with_progressbar(url, save_path):
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|||||||
@ -17,7 +17,7 @@ norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o
|
|||||||
pact_train:null
|
pact_train:null
|
||||||
fpgm_train:null
|
fpgm_train:null
|
||||||
distill_train:null
|
distill_train:null
|
||||||
null:null
|
to_static_train:Global.to_static=true
|
||||||
null:null
|
null:null
|
||||||
##
|
##
|
||||||
===========================eval_params===========================
|
===========================eval_params===========================
|
||||||
|
|||||||
@ -19,6 +19,7 @@ Global:
|
|||||||
use_space_char: true
|
use_space_char: true
|
||||||
distributed: true
|
distributed: true
|
||||||
save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
|
save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
|
||||||
|
d2s_train_image_shape: [3, 48, -1]
|
||||||
|
|
||||||
|
|
||||||
Optimizer:
|
Optimizer:
|
||||||
|
|||||||
@ -17,7 +17,7 @@ norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_d
|
|||||||
pact_train:null
|
pact_train:null
|
||||||
fpgm_train:null
|
fpgm_train:null
|
||||||
distill_train:null
|
distill_train:null
|
||||||
null:null
|
to_static_train:Global.to_static=true
|
||||||
null:null
|
null:null
|
||||||
##
|
##
|
||||||
===========================eval_params===========================
|
===========================eval_params===========================
|
||||||
|
|||||||
@ -21,6 +21,7 @@ Global:
|
|||||||
infer_mode: False
|
infer_mode: False
|
||||||
use_sync_bn: True
|
use_sync_bn: True
|
||||||
save_res_path: 'output/infer'
|
save_res_path: 'output/infer'
|
||||||
|
d2s_train_image_shape: [3, -1, -1]
|
||||||
|
|
||||||
Optimizer:
|
Optimizer:
|
||||||
name: Adam
|
name: Adam
|
||||||
|
|||||||
@ -17,7 +17,7 @@ norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o Global.print
|
|||||||
pact_train:null
|
pact_train:null
|
||||||
fpgm_train:null
|
fpgm_train:null
|
||||||
distill_train:null
|
distill_train:null
|
||||||
null:null
|
to_static_train:Global.to_static=true
|
||||||
null:null
|
null:null
|
||||||
##
|
##
|
||||||
===========================eval_params===========================
|
===========================eval_params===========================
|
||||||
|
|||||||
@ -16,7 +16,7 @@ Global:
|
|||||||
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
|
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
|
||||||
infer_mode: false
|
infer_mode: false
|
||||||
max_text_length: 500
|
max_text_length: 500
|
||||||
image_shape: [3, 480, 480]
|
d2s_train_image_shape: [3, 480, 480]
|
||||||
|
|
||||||
|
|
||||||
Optimizer:
|
Optimizer:
|
||||||
|
|||||||
@ -17,7 +17,7 @@ norm_train:tools/train.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_z
|
|||||||
pact_train:null
|
pact_train:null
|
||||||
fpgm_train:null
|
fpgm_train:null
|
||||||
distill_train:null
|
distill_train:null
|
||||||
null:null
|
to_static_train:Global.to_static=true
|
||||||
null:null
|
null:null
|
||||||
##
|
##
|
||||||
===========================eval_params===========================
|
===========================eval_params===========================
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user