mirror of https://github.com/PaddlePaddle/PaddleOCR.git, synced 2025-12-06 03:46:58 +00:00
support export after save model (#13844)
commit 2b51369324 (parent 3cc4ae9f37)
ppocr/utils/export_model.py (new file, 381 lines)
@@ -0,0 +1,381 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import json
import copy
import paddle
import paddle.nn as nn
from paddle.jit import to_static

from collections import OrderedDict
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger


def represent_dictionary_order(self, dict_data):
    return self.represent_mapping("tag:yaml.org,2002:map", dict_data.items())


def setup_orderdict():
    yaml.add_representer(OrderedDict, represent_dictionary_order)


def dump_infer_config(config, path, logger):
    setup_orderdict()
    infer_cfg = OrderedDict()
    if config["Global"].get("hpi_config_path", None):
        hpi_config = yaml.safe_load(open(config["Global"]["hpi_config_path"], "r"))
        rec_resize_img_dict = next(
            (
                item
                for item in config["Eval"]["dataset"]["transforms"]
                if "RecResizeImg" in item
            ),
            None,
        )
        if rec_resize_img_dict:
            dynamic_shapes = [1] + rec_resize_img_dict["RecResizeImg"]["image_shape"]
            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
                    "dynamic_shapes"
                ]["x"] = [dynamic_shapes for i in range(3)]
                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
                    "max_batch_size"
                ] = 1
            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
                hpi_config["Hpi"]["backend_config"]["tensorrt"]["dynamic_shapes"][
                    "x"
                ] = [dynamic_shapes for i in range(3)]
                hpi_config["Hpi"]["backend_config"]["tensorrt"]["max_batch_size"] = 1
        else:
            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("paddle_tensorrt")
                del hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"]
            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("tensorrt")
                del hpi_config["Hpi"]["backend_config"]["tensorrt"]
        infer_cfg["Hpi"] = hpi_config["Hpi"]
    if config["Global"].get("pdx_model_name", None):
        infer_cfg["Global"] = {}
        infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]

    infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
    postprocess = OrderedDict()
    for k, v in config["PostProcess"].items():
        postprocess[k] = v

    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
        tokenizer_file = config["Global"].get("rec_char_dict_path")
        if tokenizer_file is not None:
            with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
                character_dict = json.load(tokenizer_config_handle)
                postprocess["character_dict"] = character_dict
    else:
        if config["Global"].get("character_dict_path") is not None:
            with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
                lines = f.readlines()
                character_dict = [line.strip("\n") for line in lines]
            postprocess["character_dict"] = character_dict

    infer_cfg["PostProcess"] = postprocess

    with open(path, "w") as f:
        yaml.dump(
            infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
        )
    logger.info("Export inference config file to {}".format(os.path.join(path)))


def export_single_model(
    model, arch_config, save_path, logger, input_shape=None, quanter=None
):
    if arch_config["algorithm"] == "SRN":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, 64, 256], dtype="float32"),
            [
                paddle.static.InputSpec(shape=[None, 256, 1], dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, max_text_length, 1], dtype="int64"
                ),
                paddle.static.InputSpec(
                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
                ),
                paddle.static.InputSpec(
                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
                ),
            ],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "SAR":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
            [paddle.static.InputSpec(shape=[None], dtype="float32")],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 48, -1], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["SVTR", "CPPD"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "PREN":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["model_type"] == "sr":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 16, 64], dtype="float32")
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ViTSTR":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, 224, 224], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ABINet":
        if not input_shape:
            input_shape = [3, 32, 128]
        other_shape = [
            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["SATRN"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "VisionLAN":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "RobustScanner":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
            [
                paddle.static.InputSpec(
                    shape=[
                        None,
                    ],
                    dtype="float32",
                ),
                paddle.static.InputSpec(shape=[None, max_text_length], dtype="int64"),
            ],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "CAN":
        other_shape = [
            [
                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
                paddle.static.InputSpec(
                    shape=[None, arch_config["Head"]["max_text_length"]], dtype="int64"
                ),
            ]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "LaTeXOCR":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
        input_spec = [
            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # input_ids
            paddle.static.InputSpec(shape=[None, 512, 4], dtype="int64"),  # bbox
            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # attention_mask
            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # token_type_ids
            paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="int64"),  # image
        ]
        if "Re" in arch_config["Backbone"]["name"]:
            input_spec.extend(
                [
                    paddle.static.InputSpec(
                        shape=[None, 512, 3], dtype="int64"
                    ),  # entities
                    paddle.static.InputSpec(
                        shape=[None, None, 2], dtype="int64"
                    ),  # relations
                ]
            )
        if model.backbone.use_visual_backbone is False:
            input_spec.pop(4)
        model = to_static(model, input_spec=[input_spec])
    else:
        infer_shape = [3, -1, -1]
        if arch_config["model_type"] == "rec":
            infer_shape = [3, 32, -1]  # for rec model, H must be 32
            if (
                "Transform" in arch_config
                and arch_config["Transform"] is not None
                and arch_config["Transform"]["name"] == "TPS"
            ):
                logger.info(
                    "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
                )
                infer_shape[-1] = 100
        elif arch_config["model_type"] == "table":
            infer_shape = [3, 488, 488]
            if arch_config["algorithm"] == "TableMaster":
                infer_shape = [3, 480, 480]
            if arch_config["algorithm"] == "SLANet":
                infer_shape = [3, -1, -1]
        model = to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
            ],
        )

    if (
        arch_config["model_type"] != "sr"
        and arch_config["Backbone"]["name"] == "PPLCNetV3"
    ):
        # for rep lcnetv3
        for layer in model.sublayers():
            if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
                layer.rep()

    if quanter is None:
        paddle.jit.save(model, save_path)
    else:
        quanter.save_quantized_model(model, save_path)
    logger.info("inference model is saved to {}".format(save_path))
    return


def export(config, base_model=None, save_path=None):
    if paddle.distributed.get_rank() != 0:
        return
    logger = get_logger()
    # build post process
    post_process_class = build_post_process(config["PostProcess"], config["Global"])

    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in [
            "Distillation",
        ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if (
                    config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
                ):  # multi head
                    out_channels_list = {}
                    if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
                        char_num = char_num - 2
                    if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
                        char_num = char_num - 3
                    out_channels_list["CTCLabelDecode"] = char_num
                    out_channels_list["SARLabelDecode"] = char_num + 2
                    out_channels_list["NRTRLabelDecode"] = char_num + 3
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels_list"
                    ] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"
                    ] = char_num
                # just one final tensor needs to be exported for inference
                config["Architecture"]["Models"][key]["return_all_feats"] = False
        elif config["Architecture"]["Head"]["name"] == "MultiHead":  # multi head
            out_channels_list = {}
            char_num = len(getattr(post_process_class, "character"))
            if config["PostProcess"]["name"] == "SARLabelDecode":
                char_num = char_num - 2
            if config["PostProcess"]["name"] == "NRTRLabelDecode":
                char_num = char_num - 3
            out_channels_list["CTCLabelDecode"] = char_num
            out_channels_list["SARLabelDecode"] = char_num + 2
            out_channels_list["NRTRLabelDecode"] = char_num + 3
            config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num

    # for sr algorithm
    if config["Architecture"]["model_type"] == "sr":
        config["Architecture"]["Transform"]["infer_mode"] = True

    # for latexocr algorithm
    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
        config["Architecture"]["Backbone"]["is_predict"] = True
        config["Architecture"]["Backbone"]["is_export"] = True
        config["Architecture"]["Head"]["is_export"] = True
    if base_model is not None:
        model = base_model
        if isinstance(model, paddle.DataParallel):
            model = copy.deepcopy(model._layers)
        else:
            model = copy.deepcopy(model)
    else:
        model = build_model(config["Architecture"])
        load_model(config, model, model_type=config["Architecture"]["model_type"])
    model.eval()

    if not save_path:
        save_path = config["Global"]["save_inference_dir"]
    yaml_path = os.path.join(save_path, "inference.yml")

    arch_config = config["Architecture"]

    if (
        arch_config["algorithm"] in ["SVTR", "CPPD"]
        and arch_config["Head"]["name"] != "MultiHead"
    ):
        input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][
            "image_shape"
        ]
    elif arch_config["algorithm"].lower() == "ABINet".lower():
        rec_rs = [
            c
            for c in config["Eval"]["dataset"]["transforms"]
            if "ABINetRecResizeImg" in c
        ]
        input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
    else:
        input_shape = None

    if arch_config["algorithm"] in [
        "Distillation",
    ]:  # distillation model
        archs = list(arch_config["Models"].values())
        for idx, name in enumerate(model.model_name_list):
            sub_model_save_path = os.path.join(save_path, name, "inference")
            export_single_model(
                model.model_list[idx], archs[idx], sub_model_save_path, logger
            )
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(
            model, arch_config, save_path, logger, input_shape=input_shape
        )
    dump_infer_config(config, yaml_path, logger)
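The new module gives both the CLI and the training loop one programmatic entry point for export. A minimal sketch of calling it directly, assuming a hypothetical config path and output directory (load_config is the same loader the CLI uses, per the tools/export_model.py diff below):

    from tools.program import load_config
    from ppocr.utils.export_model import export

    config = load_config("configs/rec/my_rec_config.yml")  # hypothetical config file
    config["Global"]["save_inference_dir"] = "./inference/rec_demo"  # hypothetical output dir
    # With base_model=None, export() rebuilds the architecture and loads the
    # checkpoint referenced by the config; the trainer instead passes its live
    # model and an explicit save_path.
    export(config)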
ppocr/utils/save_load.py
@@ -20,6 +20,7 @@ import errno
 import os
 import pickle
 import six
+import json
 
 import paddle
 
@@ -248,6 +249,15 @@ def save_model(
         if prefix == "best_accuracy":
             arch.backbone.model.save_pretrained(best_model_path)
 
+    save_model_info = kwargs.pop("save_model_info", False)
+    if save_model_info:
+        with open(os.path.join(model_path, f"{prefix}.info.json"), "w") as f:
+            json.dump(kwargs, f)
+        logger.info("Already save model info in {}".format(model_path))
+        if prefix != "latest":
+            done_flag = kwargs.pop("done_flag", False)
+            update_train_results(config, prefix, save_model_info, done_flag=done_flag)
+
     # save metric and config
     with open(metric_prefix + ".states", "wb") as f:
         pickle.dump(kwargs, f, protocol=2)
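Note the mechanics here: save_model_info is popped off kwargs and used only as a switch; what actually lands in {prefix}.info.json is the remaining kwargs (the caller's metric dict, epoch, global_step, and so on). A hedged sketch of the resulting sidecar file, with invented values:

    import json

    # Hypothetical shape of best_accuracy.info.json: the leftover save_model
    # kwargs serialized as-is (all field values here are illustrative only).
    info = {
        "best_model_dict": {"acc": 0.8312, "best_epoch": 42},
        "epoch": 42,
        "global_step": 21000,
    }
    print(json.dumps(info, indent=2))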
@@ -255,3 +265,80 @@ def save_model(
         logger.info("save best model is to {}".format(model_prefix))
     else:
         logger.info("save model in {}".format(model_prefix))
+
+
+def update_train_results(config, prefix, metric_info, done_flag=False, last_num=5):
+    if paddle.distributed.get_rank() != 0:
+        return
+
+    assert last_num >= 1
+    train_results_path = os.path.join(
+        config["Global"]["save_model_dir"], "train_results.json"
+    )
+    save_model_tag = ["pdparams", "pdopt", "pdstates"]
+    save_inference_tag = ["inference_config", "pdmodel", "pdiparams", "pdiparams.info"]
+    if os.path.exists(train_results_path):
+        with open(train_results_path, "r") as fp:
+            train_results = json.load(fp)
+    else:
+        train_results = {}
+        train_results["model_name"] = config["Global"]["pdx_model_name"]
+        label_dict_path = os.path.abspath(
+            config["Global"].get("character_dict_path", "")
+        )
+        if label_dict_path != "":
+            if not os.path.exists(label_dict_path):
+                label_dict_path = ""
+        train_results["label_dict"] = label_dict_path
+        train_results["train_log"] = "train.log"
+        train_results["visualdl_log"] = ""
+        train_results["config"] = "config.yaml"
+        train_results["models"] = {}
+        for i in range(1, last_num + 1):
+            train_results["models"][f"last_{i}"] = {}
+        train_results["models"]["best"] = {}
+    train_results["done_flag"] = done_flag
+    if "best" in prefix:
+        if "acc" in metric_info["metric"]:
+            metric_score = metric_info["metric"]["acc"]
+        elif "precision" in metric_info["metric"]:
+            metric_score = metric_info["metric"]["precision"]
+        else:
+            raise ValueError("No metric score found.")
+        train_results["models"]["best"]["score"] = metric_score
+        for tag in save_model_tag:
+            train_results["models"]["best"][tag] = os.path.join(
+                prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states"
+            )
+        for tag in save_inference_tag:
+            train_results["models"]["best"][tag] = os.path.join(
+                prefix,
+                "inference",
+                f"inference.{tag}" if tag != "inference_config" else "inference.yml",
+            )
+    else:
+        for i in range(last_num - 1, 0, -1):
+            train_results["models"][f"last_{i + 1}"] = train_results["models"][
+                f"last_{i}"
+            ].copy()
+        if "acc" in metric_info["metric"]:
+            metric_score = metric_info["metric"]["acc"]
+        elif "precision" in metric_info["metric"]:
+            metric_score = metric_info["metric"]["precision"]
+        else:
+            raise ValueError("No metric score found.")
+        train_results["models"][f"last_{1}"]["score"] = metric_score
+        for tag in save_model_tag:
+            train_results["models"][f"last_{1}"][tag] = os.path.join(
+                prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states"
+            )
+        for tag in save_inference_tag:
+            train_results["models"][f"last_{1}"][tag] = os.path.join(
+                prefix,
+                "inference",
+                f"inference.{tag}" if tag != "inference_config" else "inference.yml",
+            )
+
+    with open(train_results_path, "w") as fp:
+        json.dump(train_results, fp)
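Taken together, update_train_results keeps a rolling index of the last five checkpoints plus the best one: each non-best save shifts last_i into last_{i+1} and writes the newest entry into last_1. A sketch of the file it maintains, following the tag rules above (all concrete values are hypothetical):

    # Hypothetical train_results.json excerpt. Per the code above, the
    # "pdstates" tag maps to <prefix>.states and "inference_config" to
    # inference.yml.
    train_results = {
        "model_name": "PP-OCRv4_mobile_rec",  # from Global.pdx_model_name
        "label_dict": "/abs/path/to/dict.txt",
        "train_log": "train.log",
        "visualdl_log": "",
        "config": "config.yaml",
        "models": {
            "last_1": {
                "score": 0.8207,
                "pdparams": "latest/latest.pdparams",
                "pdopt": "latest/latest.pdopt",
                "pdstates": "latest/latest.states",
                "inference_config": "latest/inference/inference.yml",
                "pdmodel": "latest/inference/inference.pdmodel",
                "pdiparams": "latest/inference/inference.pdiparams",
                "pdiparams.info": "latest/inference/inference.pdiparams.info",
            },
            # last_2 .. last_5 hold earlier checkpoints, rotated on each save
            "best": {
                "score": 0.8312,
                "pdparams": "best_accuracy/best_accuracy.pdparams",
                # ...same tags as above, rooted at best_accuracy/
            },
        },
        "done_flag": False,
    }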
tools/export_model.py
@@ -21,328 +21,16 @@ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
 
 import argparse
 
-import yaml
-import json
-import paddle
-from paddle.jit import to_static
-from collections import OrderedDict
-from ppocr.modeling.architectures import build_model
-from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import load_model
-from ppocr.utils.logging import get_logger
 from tools.program import load_config, merge_config, ArgsParser
+from ppocr.utils.export_model import export
 
 
-def export_single_model(
-    model, arch_config, save_path, logger, input_shape=None, quanter=None
-):
-    if arch_config["algorithm"] == "SRN":
-        max_text_length = arch_config["Head"]["max_text_length"]
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 1, 64, 256], dtype="float32"),
-            [
-                paddle.static.InputSpec(shape=[None, 256, 1], dtype="int64"),
-                paddle.static.InputSpec(
-                    shape=[None, max_text_length, 1], dtype="int64"
-                ),
-                paddle.static.InputSpec(
-                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
-                ),
-                paddle.static.InputSpec(
-                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
-                ),
-            ],
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "SAR":
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
-            [paddle.static.InputSpec(shape=[None], dtype="float32")],
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 48, -1], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["SVTR", "CPPD"]:
-        other_shape = [
-            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "PREN":
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["model_type"] == "sr":
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 16, 64], dtype="float32")
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "ViTSTR":
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 1, 224, 224], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "ABINet":
-        if not input_shape:
-            input_shape = [3, 32, 128]
-        other_shape = [
-            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["SATRN"]:
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "VisionLAN":
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "RobustScanner":
-        max_text_length = arch_config["Head"]["max_text_length"]
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
-            [
-                paddle.static.InputSpec(
-                    shape=[
-                        None,
-                    ],
-                    dtype="float32",
-                ),
-                paddle.static.InputSpec(shape=[None, max_text_length], dtype="int64"),
-            ],
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "CAN":
-        other_shape = [
-            [
-                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
-                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
-                paddle.static.InputSpec(
-                    shape=[None, arch_config["Head"]["max_text_length"]], dtype="int64"
-                ),
-            ]
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "LaTeXOCR":
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
-        input_spec = [
-            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # input_ids
-            paddle.static.InputSpec(shape=[None, 512, 4], dtype="int64"),  # bbox
-            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # attention_mask
-            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # token_type_ids
-            paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="int64"),  # image
-        ]
-        if "Re" in arch_config["Backbone"]["name"]:
-            input_spec.extend(
-                [
-                    paddle.static.InputSpec(
-                        shape=[None, 512, 3], dtype="int64"
-                    ),  # entities
-                    paddle.static.InputSpec(
-                        shape=[None, None, 2], dtype="int64"
-                    ),  # relations
-                ]
-            )
-        if model.backbone.use_visual_backbone is False:
-            input_spec.pop(4)
-        model = to_static(model, input_spec=[input_spec])
-    else:
-        infer_shape = [3, -1, -1]
-        if arch_config["model_type"] == "rec":
-            infer_shape = [3, 32, -1]  # for rec model, H must be 32
-            if (
-                "Transform" in arch_config
-                and arch_config["Transform"] is not None
-                and arch_config["Transform"]["name"] == "TPS"
-            ):
-                logger.info(
-                    "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
-                )
-                infer_shape[-1] = 100
-        elif arch_config["model_type"] == "table":
-            infer_shape = [3, 488, 488]
-            if arch_config["algorithm"] == "TableMaster":
-                infer_shape = [3, 480, 480]
-            if arch_config["algorithm"] == "SLANet":
-                infer_shape = [3, -1, -1]
-        model = to_static(
-            model,
-            input_spec=[
-                paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
-            ],
-        )
-
-    if (
-        arch_config["model_type"] != "sr"
-        and arch_config["Backbone"]["name"] == "PPLCNetV3"
-    ):
-        # for rep lcnetv3
-        for layer in model.sublayers():
-            if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
-                layer.rep()
-
-    if quanter is None:
-        paddle.jit.save(model, save_path)
-    else:
-        quanter.save_quantized_model(model, save_path)
-    logger.info("inference model is saved to {}".format(save_path))
-    return
-
-
-def represent_dictionary_order(self, dict_data):
-    return self.represent_mapping("tag:yaml.org,2002:map", dict_data.items())
-
-
-def setup_orderdict():
-    yaml.add_representer(OrderedDict, represent_dictionary_order)
-
-
-def dump_infer_config(config, path, logger):
-    setup_orderdict()
-    infer_cfg = OrderedDict()
-
-    infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
-    postprocess = OrderedDict()
-    for k, v in config["PostProcess"].items():
-        postprocess[k] = v
-
-    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
-        tokenizer_file = config["Global"].get("rec_char_dict_path")
-        if tokenizer_file is not None:
-            with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
-                character_dict = json.load(tokenizer_config_handle)
-                postprocess["character_dict"] = character_dict
-    else:
-        if config["Global"].get("character_dict_path") is not None:
-            with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
-                lines = f.readlines()
-                character_dict = [line.strip("\n") for line in lines]
-            postprocess["character_dict"] = character_dict
-
-    infer_cfg["PostProcess"] = postprocess
-
-    with open(path, "w") as f:
-        yaml.dump(
-            infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
-        )
-    logger.info("Export inference config file to {}".format(os.path.join(path)))
-
-
 def main():
     FLAGS = ArgsParser().parse_args()
     config = load_config(FLAGS.config)
     config = merge_config(config, FLAGS.opt)
-    logger = get_logger()
-    # build post process
-
-    post_process_class = build_post_process(config["PostProcess"], config["Global"])
-
-    # build model
-    # for rec algorithm
-    if hasattr(post_process_class, "character"):
-        char_num = len(getattr(post_process_class, "character"))
-        if config["Architecture"]["algorithm"] in [
-            "Distillation",
-        ]:  # distillation model
-            for key in config["Architecture"]["Models"]:
-                if (
-                    config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
-                ):  # multi head
-                    out_channels_list = {}
-                    if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
-                        char_num = char_num - 2
-                    if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
-                        char_num = char_num - 3
-                    out_channels_list["CTCLabelDecode"] = char_num
-                    out_channels_list["SARLabelDecode"] = char_num + 2
-                    out_channels_list["NRTRLabelDecode"] = char_num + 3
-                    config["Architecture"]["Models"][key]["Head"][
-                        "out_channels_list"
-                    ] = out_channels_list
-                else:
-                    config["Architecture"]["Models"][key]["Head"][
-                        "out_channels"
-                    ] = char_num
-                # just one final tensor needs to be exported for inference
-                config["Architecture"]["Models"][key]["return_all_feats"] = False
-        elif config["Architecture"]["Head"]["name"] == "MultiHead":  # multi head
-            out_channels_list = {}
-            char_num = len(getattr(post_process_class, "character"))
-            if config["PostProcess"]["name"] == "SARLabelDecode":
-                char_num = char_num - 2
-            if config["PostProcess"]["name"] == "NRTRLabelDecode":
-                char_num = char_num - 3
-            out_channels_list["CTCLabelDecode"] = char_num
-            out_channels_list["SARLabelDecode"] = char_num + 2
-            out_channels_list["NRTRLabelDecode"] = char_num + 3
-            config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
-        else:  # base rec model
-            config["Architecture"]["Head"]["out_channels"] = char_num
-
-    # for sr algorithm
-    if config["Architecture"]["model_type"] == "sr":
-        config["Architecture"]["Transform"]["infer_mode"] = True
-
-    # for latexocr algorithm
-    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
-        config["Architecture"]["Backbone"]["is_predict"] = True
-        config["Architecture"]["Backbone"]["is_export"] = True
-        config["Architecture"]["Head"]["is_export"] = True
-    model = build_model(config["Architecture"])
-    load_model(config, model, model_type=config["Architecture"]["model_type"])
-    model.eval()
-
-    save_path = config["Global"]["save_inference_dir"]
-    yaml_path = os.path.join(save_path, "inference.yml")
-
-    arch_config = config["Architecture"]
-
-    if (
-        arch_config["algorithm"] in ["SVTR", "CPPD"]
-        and arch_config["Head"]["name"] != "MultiHead"
-    ):
-        input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][
-            "image_shape"
-        ]
-    elif arch_config["algorithm"].lower() == "ABINet".lower():
-        rec_rs = [
-            c
-            for c in config["Eval"]["dataset"]["transforms"]
-            if "ABINetRecResizeImg" in c
-        ]
-        input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
-    else:
-        input_shape = None
-
-    if arch_config["algorithm"] in [
-        "Distillation",
-    ]:  # distillation model
-        archs = list(arch_config["Models"].values())
-        for idx, name in enumerate(model.model_name_list):
-            sub_model_save_path = os.path.join(save_path, name, "inference")
-            export_single_model(
-                model.model_list[idx], archs[idx], sub_model_save_path, logger
-            )
-    else:
-        save_path = os.path.join(save_path, "inference")
-        export_single_model(
-            model, arch_config, save_path, logger, input_shape=input_shape
-        )
-    dump_infer_config(config, yaml_path, logger)
+    # export model
+    export(config)
 
 
 if __name__ == "__main__":
tools/program.py
@@ -27,6 +27,7 @@ import paddle.distributed as dist
 from tqdm import tqdm
 import cv2
 import numpy as np
+import copy
 from argparse import ArgumentParser, RawDescriptionHelpFormatter
 
 from ppocr.utils.stats import TrainingStats
@@ -36,6 +37,7 @@ from ppocr.utils.logging import get_logger
 from ppocr.utils.loggers import WandbLogger, Loggers
 from ppocr.utils import profiler
 from ppocr.data import build_dataloader
+from ppocr.utils.export_model import export
 
 
 class ArgsParser(ArgumentParser):
@@ -205,6 +207,7 @@ def train(
     eval_batch_epoch = config["Global"].get("eval_batch_epoch", None)
     profiler_options = config["profiler_options"]
     print_mem_info = config["Global"].get("print_mem_info", True)
+    uniform_output_enabled = config["Global"].get("uniform_output_enabled", False)
 
     global_step = 0
     if "global_step" in pre_best_model_dict:
@@ -303,6 +306,7 @@ def train(
         )
 
         for idx, batch in enumerate(train_dataloader):
+            model.train()
             profiler.add_profiler_step(profiler_options)
             train_reader_cost += time.time() - reader_start
             if idx >= max_iter:
@@ -484,14 +488,29 @@ def train(
                 if cur_metric[main_indicator] >= best_model_dict[main_indicator]:
                     best_model_dict.update(cur_metric)
                     best_model_dict["best_epoch"] = epoch
+                    prefix = "best_accuracy"
+                    if uniform_output_enabled:
+                        export(
+                            config,
+                            model,
+                            os.path.join(save_model_dir, prefix, "inference"),
+                        )
+                        model_info = {"epoch": epoch, "metric": best_model_dict}
+                    else:
+                        model_info = None
                     save_model(
                         model,
                         optimizer,
-                        save_model_dir,
+                        (
+                            os.path.join(save_model_dir, prefix)
+                            if uniform_output_enabled
+                            else save_model_dir
+                        ),
                         logger,
                         config,
                         is_best=True,
-                        prefix="best_accuracy",
+                        prefix=prefix,
+                        save_model_info=model_info,
                         best_model_dict=best_model_dict,
                         epoch=epoch,
                         global_step=global_step,
@@ -520,14 +539,25 @@ def train(
 
         reader_start = time.time()
         if dist.get_rank() == 0:
+            prefix = "latest"
+            if uniform_output_enabled:
+                export(config, model, os.path.join(save_model_dir, prefix, "inference"))
+                model_info = {"epoch": epoch, "metric": best_model_dict}
+            else:
+                model_info = None
             save_model(
                 model,
                 optimizer,
-                save_model_dir,
+                (
+                    os.path.join(save_model_dir, prefix)
+                    if uniform_output_enabled
+                    else save_model_dir
+                ),
                 logger,
                 config,
                 is_best=False,
-                prefix="latest",
+                prefix=prefix,
+                save_model_info=model_info,
                 best_model_dict=best_model_dict,
                 epoch=epoch,
                 global_step=global_step,
@@ -537,17 +567,29 @@ def train(
                 log_writer.log_model(is_best=False, prefix="latest")
 
         if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
+            prefix = "iter_epoch_{}".format(epoch)
+            if uniform_output_enabled:
+                export(config, model, os.path.join(save_model_dir, prefix, "inference"))
+                model_info = {"epoch": epoch, "metric": best_model_dict}
+            else:
+                model_info = None
             save_model(
                 model,
                 optimizer,
-                save_model_dir,
+                (
+                    os.path.join(save_model_dir, prefix)
+                    if uniform_output_enabled
+                    else save_model_dir
+                ),
                 logger,
                 config,
                 is_best=False,
-                prefix="iter_epoch_{}".format(epoch),
+                prefix=prefix,
+                save_model_info=model_info,
                 best_model_dict=best_model_dict,
                 epoch=epoch,
                 global_step=global_step,
+                done_flag=epoch == config["Global"]["epoch_num"],
             )
         if log_writer is not None:
             log_writer.log_model(
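The uniform_output_enabled switch is what ties the pieces together: every save now gets its own prefix directory containing the Paddle checkpoint, the {prefix}.info.json sidecar, and a freshly exported inference model. A sketch of the resulting layout (the save_model_dir value is hypothetical; the prefixes are the ones used above):

    import os

    save_model_dir = "./output/rec_demo"  # hypothetical Global.save_model_dir
    for prefix in ["best_accuracy", "latest", "iter_epoch_10"]:
        # checkpoint + <prefix>.info.json, written by save_model
        print(os.path.join(save_model_dir, prefix))
        # static-graph model + inference.yml, written by export()
        print(os.path.join(save_model_dir, prefix, "inference"))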
tools/train.py
@@ -166,6 +166,15 @@ def main(config, device, logger, vdl_writer, seed):
     amp_dtype = config["Global"].get("amp_dtype", "float16")
     amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
     amp_custom_white_list = config["Global"].get("amp_custom_white_list", [])
+    if os.path.exists(
+        os.path.join(config["Global"]["save_model_dir"], "train_results.json")
+    ):
+        try:
+            os.remove(
+                os.path.join(config["Global"]["save_model_dir"], "train_results.json")
+            )
+        except:
+            pass
     if use_amp:
         AMP_RELATED_FLAGS_SETTING = {
             "FLAGS_max_inplace_grad_add": 8,