support export after save model (#13844)

zhangyubo0722 2024-09-25 01:11:01 +08:00 committed by GitHub
parent 3cc4ae9f37
commit 2b51369324
5 changed files with 528 additions and 321 deletions

ppocr/utils/export_model.py (new file, +381 lines)

@@ -0,0 +1,381 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
import json
import copy
import paddle
import paddle.nn as nn
from paddle.jit import to_static
from collections import OrderedDict
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger


def represent_dictionary_order(self, dict_data):
    return self.represent_mapping("tag:yaml.org,2002:map", dict_data.items())


def setup_orderdict():
    yaml.add_representer(OrderedDict, represent_dictionary_order)


def dump_infer_config(config, path, logger):
    setup_orderdict()
    infer_cfg = OrderedDict()
if config["Global"].get("hpi_config_path", None):
        with open(config["Global"]["hpi_config_path"], "r") as hpi_f:
            hpi_config = yaml.safe_load(hpi_f)
rec_resize_img_dict = next(
(
item
for item in config["Eval"]["dataset"]["transforms"]
if "RecResizeImg" in item
),
None,
)
if rec_resize_img_dict:
dynamic_shapes = [1] + rec_resize_img_dict["RecResizeImg"]["image_shape"]
if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
"dynamic_shapes"
]["x"] = [dynamic_shapes for i in range(3)]
hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
"max_batch_size"
] = 1
if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
hpi_config["Hpi"]["backend_config"]["tensorrt"]["dynamic_shapes"][
"x"
] = [dynamic_shapes for i in range(3)]
hpi_config["Hpi"]["backend_config"]["tensorrt"]["max_batch_size"] = 1
else:
if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
hpi_config["Hpi"]["supported_backends"]["gpu"].remove("paddle_tensorrt")
del hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"]
if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
hpi_config["Hpi"]["supported_backends"]["gpu"].remove("tensorrt")
del hpi_config["Hpi"]["backend_config"]["tensorrt"]
infer_cfg["Hpi"] = hpi_config["Hpi"]
if config["Global"].get("pdx_model_name", None):
infer_cfg["Global"] = {}
infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]
infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
postprocess = OrderedDict()
for k, v in config["PostProcess"].items():
postprocess[k] = v
if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
tokenizer_file = config["Global"].get("rec_char_dict_path")
if tokenizer_file is not None:
with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
character_dict = json.load(tokenizer_config_handle)
postprocess["character_dict"] = character_dict
else:
if config["Global"].get("character_dict_path") is not None:
with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
lines = f.readlines()
character_dict = [line.strip("\n") for line in lines]
postprocess["character_dict"] = character_dict
infer_cfg["PostProcess"] = postprocess
with open(path, "w") as f:
yaml.dump(
infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
)
logger.info("Export inference config file to {}".format(os.path.join(path)))


def export_single_model(
    model, arch_config, save_path, logger, input_shape=None, quanter=None
):
if arch_config["algorithm"] == "SRN":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(shape=[None, 1, 64, 256], dtype="float32"),
[
paddle.static.InputSpec(shape=[None, 256, 1], dtype="int64"),
paddle.static.InputSpec(
shape=[None, max_text_length, 1], dtype="int64"
),
paddle.static.InputSpec(
shape=[None, 8, max_text_length, max_text_length], dtype="int64"
),
paddle.static.InputSpec(
shape=[None, 8, max_text_length, max_text_length], dtype="int64"
),
],
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SAR":
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
[paddle.static.InputSpec(shape=[None], dtype="float32")],
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 48, -1], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["SVTR", "CPPD"]:
other_shape = [
paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "PREN":
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["model_type"] == "sr":
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 16, 64], dtype="float32")
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "ViTSTR":
other_shape = [
paddle.static.InputSpec(shape=[None, 1, 224, 224], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "ABINet":
if not input_shape:
input_shape = [3, 32, 128]
other_shape = [
paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
other_shape = [
paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["SATRN"]:
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "VisionLAN":
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "RobustScanner":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
[
paddle.static.InputSpec(
shape=[
None,
],
dtype="float32",
),
paddle.static.InputSpec(shape=[None, max_text_length], dtype="int64"),
],
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "CAN":
other_shape = [
[
paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
paddle.static.InputSpec(
shape=[None, arch_config["Head"]["max_text_length"]], dtype="int64"
),
]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "LaTeXOCR":
other_shape = [
paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # input_ids
paddle.static.InputSpec(shape=[None, 512, 4], dtype="int64"), # bbox
paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # attention_mask
paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # token_type_ids
paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="int64"), # image
]
if "Re" in arch_config["Backbone"]["name"]:
input_spec.extend(
[
paddle.static.InputSpec(
shape=[None, 512, 3], dtype="int64"
), # entities
paddle.static.InputSpec(
shape=[None, None, 2], dtype="int64"
), # relations
]
)
if model.backbone.use_visual_backbone is False:
input_spec.pop(4)
model = to_static(model, input_spec=[input_spec])
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
if (
"Transform" in arch_config
and arch_config["Transform"] is not None
and arch_config["Transform"]["name"] == "TPS"
):
                logger.info(
                    "When TPS is used in the network, variable-length input is not supported; the input size must match the training size."
                )
                infer_shape[-1] = 100
elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488]
if arch_config["algorithm"] == "TableMaster":
infer_shape = [3, 480, 480]
if arch_config["algorithm"] == "SLANet":
infer_shape = [3, -1, -1]
model = to_static(
model,
input_spec=[
paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
],
)
if (
arch_config["model_type"] != "sr"
and arch_config["Backbone"]["name"] == "PPLCNetV3"
):
# for rep lcnetv3
for layer in model.sublayers():
if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
layer.rep()
if quanter is None:
paddle.jit.save(model, save_path)
else:
quanter.save_quantized_model(model, save_path)
logger.info("inference model is saved to {}".format(save_path))
return


def export(config, base_model=None, save_path=None):
if paddle.distributed.get_rank() != 0:
return
logger = get_logger()
# build post process
post_process_class = build_post_process(config["PostProcess"], config["Global"])
# build model
# for rec algorithm
if hasattr(post_process_class, "character"):
char_num = len(getattr(post_process_class, "character"))
if config["Architecture"]["algorithm"] in [
"Distillation",
]: # distillation model
for key in config["Architecture"]["Models"]:
if (
config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
): # multi head
out_channels_list = {}
if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
char_num = char_num - 2
if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
char_num = char_num - 3
out_channels_list["CTCLabelDecode"] = char_num
out_channels_list["SARLabelDecode"] = char_num + 2
out_channels_list["NRTRLabelDecode"] = char_num + 3
config["Architecture"]["Models"][key]["Head"][
"out_channels_list"
] = out_channels_list
else:
config["Architecture"]["Models"][key]["Head"][
"out_channels"
] = char_num
                # just one final tensor needs to be exported for inference
                config["Architecture"]["Models"][key]["return_all_feats"] = False
elif config["Architecture"]["Head"]["name"] == "MultiHead": # multi head
out_channels_list = {}
char_num = len(getattr(post_process_class, "character"))
if config["PostProcess"]["name"] == "SARLabelDecode":
char_num = char_num - 2
if config["PostProcess"]["name"] == "NRTRLabelDecode":
char_num = char_num - 3
out_channels_list["CTCLabelDecode"] = char_num
out_channels_list["SARLabelDecode"] = char_num + 2
out_channels_list["NRTRLabelDecode"] = char_num + 3
config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
# for sr algorithm
if config["Architecture"]["model_type"] == "sr":
config["Architecture"]["Transform"]["infer_mode"] = True
# for latexocr algorithm
if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
config["Architecture"]["Backbone"]["is_predict"] = True
config["Architecture"]["Backbone"]["is_export"] = True
config["Architecture"]["Head"]["is_export"] = True
if base_model is not None:
model = base_model
if isinstance(model, paddle.DataParallel):
model = copy.deepcopy(model._layers)
else:
model = copy.deepcopy(model)
else:
model = build_model(config["Architecture"])
load_model(config, model, model_type=config["Architecture"]["model_type"])
model.eval()
if not save_path:
save_path = config["Global"]["save_inference_dir"]
yaml_path = os.path.join(save_path, "inference.yml")
arch_config = config["Architecture"]
if (
arch_config["algorithm"] in ["SVTR", "CPPD"]
and arch_config["Head"]["name"] != "MultiHead"
):
input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][
"image_shape"
]
elif arch_config["algorithm"].lower() == "ABINet".lower():
rec_rs = [
c
for c in config["Eval"]["dataset"]["transforms"]
if "ABINetRecResizeImg" in c
]
input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
else:
input_shape = None
if arch_config["algorithm"] in [
"Distillation",
]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(
model.model_list[idx], archs[idx], sub_model_save_path, logger
)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(
model, arch_config, save_path, logger, input_shape=input_shape
)
dump_infer_config(config, yaml_path, logger)

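For orientation, a minimal sketch of how the new helper is driven on its own (the config path is hypothetical; the slimmed-down tools/export_model.py below does essentially this after parsing its command-line flags):

    from tools.program import load_config
    from ppocr.utils.export_model import export

    config = load_config("configs/rec/example_rec.yml")  # hypothetical path
    export(config)  # writes the inference model plus inference.yml
                    # under Global.save_inference_dir
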
ppocr/utils/save_load.py

@@ -20,6 +20,7 @@ import errno
import os
import pickle
import six
import json
import paddle
@@ -248,6 +249,15 @@ def save_model(
        if prefix == "best_accuracy":
            arch.backbone.model.save_pretrained(best_model_path)
save_model_info = kwargs.pop("save_model_info", False)
if save_model_info:
with open(os.path.join(model_path, f"{prefix}.info.json"), "w") as f:
json.dump(kwargs, f)
logger.info("Already save model info in {}".format(model_path))
if prefix != "latest":
done_flag = kwargs.pop("done_flag", False)
update_train_results(config, prefix, save_model_info, done_flag=done_flag)
    # save metric and config
    with open(metric_prefix + ".states", "wb") as f:
        pickle.dump(kwargs, f, protocol=2)
@@ -255,3 +265,80 @@ def save_model(
logger.info("save best model is to {}".format(model_prefix)) logger.info("save best model is to {}".format(model_prefix))
else: else:
logger.info("save model in {}".format(model_prefix)) logger.info("save model in {}".format(model_prefix))


def update_train_results(config, prefix, metric_info, done_flag=False, last_num=5):
if paddle.distributed.get_rank() != 0:
return
assert last_num >= 1
train_results_path = os.path.join(
config["Global"]["save_model_dir"], "train_results.json"
)
save_model_tag = ["pdparams", "pdopt", "pdstates"]
save_inference_tag = ["inference_config", "pdmodel", "pdiparams", "pdiparams.info"]
if os.path.exists(train_results_path):
with open(train_results_path, "r") as fp:
train_results = json.load(fp)
else:
train_results = {}
train_results["model_name"] = config["Global"]["pdx_model_name"]
label_dict_path = os.path.abspath(
config["Global"].get("character_dict_path", "")
)
    if label_dict_path != "":
        if not os.path.exists(label_dict_path):
            label_dict_path = ""
    train_results["label_dict"] = label_dict_path
train_results["train_log"] = "train.log"
train_results["visualdl_log"] = ""
train_results["config"] = "config.yaml"
train_results["models"] = {}
for i in range(1, last_num + 1):
train_results["models"][f"last_{i}"] = {}
train_results["models"]["best"] = {}
train_results["done_flag"] = done_flag
if "best" in prefix:
if "acc" in metric_info["metric"]:
metric_score = metric_info["metric"]["acc"]
elif "precision" in metric_info["metric"]:
metric_score = metric_info["metric"]["precision"]
else:
raise ValueError("No metric score found.")
train_results["models"]["best"]["score"] = metric_score
for tag in save_model_tag:
train_results["models"]["best"][tag] = os.path.join(
prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states"
)
for tag in save_inference_tag:
train_results["models"]["best"][tag] = os.path.join(
prefix,
"inference",
f"inference.{tag}" if tag != "inference_config" else "inference.yml",
)
else:
for i in range(last_num - 1, 0, -1):
train_results["models"][f"last_{i + 1}"] = train_results["models"][
f"last_{i}"
].copy()
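        # NOTE (sketch, not part of this commit): the loop above shifts the
        # checkpoint slots so that last_1 is always the newest
        # (last_4 -> last_5, ..., last_1 -> last_2) before the fresh metrics
        # are written into last_1 below.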
if "acc" in metric_info["metric"]:
metric_score = metric_info["metric"]["acc"]
elif "precision" in metric_info["metric"]:
metric_score = metric_info["metric"]["precision"]
else:
raise ValueError("No metric score found.")
train_results["models"][f"last_{1}"]["score"] = metric_score
for tag in save_model_tag:
train_results["models"][f"last_{1}"][tag] = os.path.join(
prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states"
)
for tag in save_inference_tag:
train_results["models"][f"last_{1}"][tag] = os.path.join(
prefix,
"inference",
f"inference.{tag}" if tag != "inference_config" else "inference.yml",
)
with open(train_results_path, "w") as fp:
json.dump(train_results, fp)
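
For reference, the train_results.json maintained above ends up shaped roughly like this (a sketch with illustrative values, inferred from update_train_results()):

    # {
    #   "model_name": "<Global.pdx_model_name>",
    #   "label_dict": "/abs/path/to/character_dict.txt",
    #   "train_log": "train.log",
    #   "visualdl_log": "",
    #   "config": "config.yaml",
    #   "models": {
    #     "best":   {"score": 0.86, "pdparams": "best_accuracy/best_accuracy.pdparams", ...},
    #     "last_1": {...},  # newest rolling checkpoint
    #     "last_5": {...}   # oldest retained (last_num defaults to 5)
    #   },
    #   "done_flag": false
    # }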

tools/export_model.py

@@ -21,328 +21,16 @@ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
import argparse
import yaml
import json
import paddle
from paddle.jit import to_static
from collections import OrderedDict
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
from ppocr.utils.export_model import export


def export_single_model(
    model, arch_config, save_path, logger, input_shape=None, quanter=None
):
if arch_config["algorithm"] == "SRN":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(shape=[None, 1, 64, 256], dtype="float32"),
[
paddle.static.InputSpec(shape=[None, 256, 1], dtype="int64"),
paddle.static.InputSpec(
shape=[None, max_text_length, 1], dtype="int64"
),
paddle.static.InputSpec(
shape=[None, 8, max_text_length, max_text_length], dtype="int64"
),
paddle.static.InputSpec(
shape=[None, 8, max_text_length, max_text_length], dtype="int64"
),
],
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SAR":
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
[paddle.static.InputSpec(shape=[None], dtype="float32")],
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 48, -1], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["SVTR", "CPPD"]:
other_shape = [
paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "PREN":
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["model_type"] == "sr":
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 16, 64], dtype="float32")
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "ViTSTR":
other_shape = [
paddle.static.InputSpec(shape=[None, 1, 224, 224], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "ABINet":
if not input_shape:
input_shape = [3, 32, 128]
other_shape = [
paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
other_shape = [
paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["SATRN"]:
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "VisionLAN":
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "RobustScanner":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
[
paddle.static.InputSpec(
shape=[
None,
],
dtype="float32",
),
paddle.static.InputSpec(shape=[None, max_text_length], dtype="int64"),
],
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "CAN":
other_shape = [
[
paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
paddle.static.InputSpec(
shape=[None, arch_config["Head"]["max_text_length"]], dtype="int64"
),
]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "LaTeXOCR":
other_shape = [
paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # input_ids
paddle.static.InputSpec(shape=[None, 512, 4], dtype="int64"), # bbox
paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # attention_mask
paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # token_type_ids
paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="int64"), # image
]
if "Re" in arch_config["Backbone"]["name"]:
input_spec.extend(
[
paddle.static.InputSpec(
shape=[None, 512, 3], dtype="int64"
), # entities
paddle.static.InputSpec(
shape=[None, None, 2], dtype="int64"
), # relations
]
)
if model.backbone.use_visual_backbone is False:
input_spec.pop(4)
model = to_static(model, input_spec=[input_spec])
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
if (
"Transform" in arch_config
and arch_config["Transform"] is not None
and arch_config["Transform"]["name"] == "TPS"
):
                logger.info(
                    "When TPS is used in the network, variable-length input is not supported; the input size must match the training size."
                )
                infer_shape[-1] = 100
elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488]
if arch_config["algorithm"] == "TableMaster":
infer_shape = [3, 480, 480]
if arch_config["algorithm"] == "SLANet":
infer_shape = [3, -1, -1]
model = to_static(
model,
input_spec=[
paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
],
)
if (
arch_config["model_type"] != "sr"
and arch_config["Backbone"]["name"] == "PPLCNetV3"
):
# for rep lcnetv3
for layer in model.sublayers():
if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
layer.rep()
if quanter is None:
paddle.jit.save(model, save_path)
else:
quanter.save_quantized_model(model, save_path)
logger.info("inference model is saved to {}".format(save_path))
return


def represent_dictionary_order(self, dict_data):
    return self.represent_mapping("tag:yaml.org,2002:map", dict_data.items())


def setup_orderdict():
    yaml.add_representer(OrderedDict, represent_dictionary_order)


def dump_infer_config(config, path, logger):
    setup_orderdict()
    infer_cfg = OrderedDict()
infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
postprocess = OrderedDict()
for k, v in config["PostProcess"].items():
postprocess[k] = v
if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
tokenizer_file = config["Global"].get("rec_char_dict_path")
if tokenizer_file is not None:
with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
character_dict = json.load(tokenizer_config_handle)
postprocess["character_dict"] = character_dict
else:
if config["Global"].get("character_dict_path") is not None:
with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
lines = f.readlines()
character_dict = [line.strip("\n") for line in lines]
postprocess["character_dict"] = character_dict
infer_cfg["PostProcess"] = postprocess
with open(path, "w") as f:
yaml.dump(
infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
)
logger.info("Export inference config file to {}".format(os.path.join(path)))


def main():
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    # export model
    export(config)
    logger = get_logger()
    # build post process
post_process_class = build_post_process(config["PostProcess"], config["Global"])
# build model
# for rec algorithm
if hasattr(post_process_class, "character"):
char_num = len(getattr(post_process_class, "character"))
if config["Architecture"]["algorithm"] in [
"Distillation",
]: # distillation model
for key in config["Architecture"]["Models"]:
if (
config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
): # multi head
out_channels_list = {}
if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
char_num = char_num - 2
if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
char_num = char_num - 3
out_channels_list["CTCLabelDecode"] = char_num
out_channels_list["SARLabelDecode"] = char_num + 2
out_channels_list["NRTRLabelDecode"] = char_num + 3
config["Architecture"]["Models"][key]["Head"][
"out_channels_list"
] = out_channels_list
else:
config["Architecture"]["Models"][key]["Head"][
"out_channels"
] = char_num
                # just one final tensor needs to be exported for inference
config["Architecture"]["Models"][key]["return_all_feats"] = False
elif config["Architecture"]["Head"]["name"] == "MultiHead": # multi head
out_channels_list = {}
char_num = len(getattr(post_process_class, "character"))
if config["PostProcess"]["name"] == "SARLabelDecode":
char_num = char_num - 2
if config["PostProcess"]["name"] == "NRTRLabelDecode":
char_num = char_num - 3
out_channels_list["CTCLabelDecode"] = char_num
out_channels_list["SARLabelDecode"] = char_num + 2
out_channels_list["NRTRLabelDecode"] = char_num + 3
config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
# for sr algorithm
if config["Architecture"]["model_type"] == "sr":
config["Architecture"]["Transform"]["infer_mode"] = True
# for latexocr algorithm
if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
config["Architecture"]["Backbone"]["is_predict"] = True
config["Architecture"]["Backbone"]["is_export"] = True
config["Architecture"]["Head"]["is_export"] = True
model = build_model(config["Architecture"])
load_model(config, model, model_type=config["Architecture"]["model_type"])
model.eval()
save_path = config["Global"]["save_inference_dir"]
yaml_path = os.path.join(save_path, "inference.yml")
arch_config = config["Architecture"]
if (
arch_config["algorithm"] in ["SVTR", "CPPD"]
and arch_config["Head"]["name"] != "MultiHead"
):
input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][
"image_shape"
]
elif arch_config["algorithm"].lower() == "ABINet".lower():
rec_rs = [
c
for c in config["Eval"]["dataset"]["transforms"]
if "ABINetRecResizeImg" in c
]
input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
else:
input_shape = None
if arch_config["algorithm"] in [
"Distillation",
]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(
model.model_list[idx], archs[idx], sub_model_save_path, logger
)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(
model, arch_config, save_path, logger, input_shape=input_shape
)
dump_infer_config(config, yaml_path, logger)


if __name__ == "__main__":

tools/program.py

@@ -27,6 +27,7 @@ import paddle.distributed as dist
from tqdm import tqdm
import cv2
import numpy as np
import copy
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.utils.stats import TrainingStats
@@ -36,6 +37,7 @@ from ppocr.utils.logging import get_logger
from ppocr.utils.loggers import WandbLogger, Loggers
from ppocr.utils import profiler
from ppocr.data import build_dataloader
from ppocr.utils.export_model import export


class ArgsParser(ArgumentParser):
@@ -205,6 +207,7 @@ def train(
    eval_batch_epoch = config["Global"].get("eval_batch_epoch", None)
    profiler_options = config["profiler_options"]
    print_mem_info = config["Global"].get("print_mem_info", True)
    uniform_output_enabled = config["Global"].get("uniform_output_enabled", False)

    global_step = 0
    if "global_step" in pre_best_model_dict:
@@ -303,6 +306,7 @@
        )
        for idx, batch in enumerate(train_dataloader):
            model.train()
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
@@ -484,14 +488,29 @@ def train(
            if cur_metric[main_indicator] >= best_model_dict[main_indicator]:
                best_model_dict.update(cur_metric)
                best_model_dict["best_epoch"] = epoch
                prefix = "best_accuracy"
                if uniform_output_enabled:
                    export(
                        config,
                        model,
                        os.path.join(save_model_dir, prefix, "inference"),
                    )
                    model_info = {"epoch": epoch, "metric": best_model_dict}
                else:
                    model_info = None
                save_model(
                    model,
                    optimizer,
                    save_model_dir,
                    (
                        os.path.join(save_model_dir, prefix)
                        if uniform_output_enabled
                        else save_model_dir
                    ),
                    logger,
                    config,
                    is_best=True,
                    prefix="best_accuracy",
                    prefix=prefix,
                    save_model_info=model_info,
                    best_model_dict=best_model_dict,
                    epoch=epoch,
                    global_step=global_step,
@@ -520,14 +539,25 @@ def train(
            reader_start = time.time()

        if dist.get_rank() == 0:
            prefix = "latest"
            if uniform_output_enabled:
                export(config, model, os.path.join(save_model_dir, prefix, "inference"))
                model_info = {"epoch": epoch, "metric": best_model_dict}
            else:
                model_info = None
            save_model(
                model,
                optimizer,
                save_model_dir,
                (
                    os.path.join(save_model_dir, prefix)
                    if uniform_output_enabled
                    else save_model_dir
                ),
                logger,
                config,
                is_best=False,
                prefix="latest",
                prefix=prefix,
                save_model_info=model_info,
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step,
@@ -537,17 +567,29 @@ def train(
            log_writer.log_model(is_best=False, prefix="latest")

        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            prefix = "iter_epoch_{}".format(epoch)
            if uniform_output_enabled:
                export(config, model, os.path.join(save_model_dir, prefix, "inference"))
                model_info = {"epoch": epoch, "metric": best_model_dict}
            else:
                model_info = None
            save_model(
                model,
                optimizer,
                save_model_dir,
                (
                    os.path.join(save_model_dir, prefix)
                    if uniform_output_enabled
                    else save_model_dir
                ),
                logger,
                config,
                is_best=False,
                prefix="iter_epoch_{}".format(epoch),
                prefix=prefix,
                save_model_info=model_info,
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step,
                done_flag=epoch == config["Global"]["epoch_num"],
            )
            if log_writer is not None:
                log_writer.log_model(

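Taken together with the save_model() changes above, enabling Global.uniform_output_enabled switches every checkpoint to a per-prefix directory that also carries an exported inference model. The resulting layout is roughly this (a sketch inferred from the call sites; names are illustrative):

    # save_model_dir/
    #   train_results.json
    #   best_accuracy/
    #     best_accuracy.pdparams / .pdopt / .states
    #     best_accuracy.info.json          # written when save_model_info is set
    #     inference/
    #       inference.yml, inference.pdmodel, inference.pdiparams
    #   latest/              # same structure
    #   iter_epoch_<N>/      # every save_epoch_step epochs; done_flag set on the last
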
tools/train.py

@@ -166,6 +166,15 @@ def main(config, device, logger, vdl_writer, seed)
    amp_dtype = config["Global"].get("amp_dtype", "float16")
    amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
    amp_custom_white_list = config["Global"].get("amp_custom_white_list", [])
if os.path.exists(
os.path.join(config["Global"]["save_model_dir"], "train_results.json")
):
try:
os.remove(
os.path.join(config["Global"]["save_model_dir"], "train_results.json")
)
        except OSError:
            # best-effort cleanup of a stale train_results.json; ignore failures
            pass
    if use_amp:
        AMP_RELATED_FLAGS_SETTING = {
            "FLAGS_max_inplace_grad_add": 8,