mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-11-08 05:43:26 +00:00
Toward Devkit Consistency (#10150)
* Accommodate UAPI
* Fix signal handler
* Save model.pdopt
* Change variable name
* Update vdl dir
This commit is contained in:
parent
15abbcc41e
commit
2d44a71b20
@ -25,7 +25,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
|||||||
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from ppocr.data import build_dataloader
|
from ppocr.data import build_dataloader, set_signal_handlers
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
|
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
@ -39,6 +39,7 @@ def main(config, device, logger, vdl_writer):
|
|||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
|
|
||||||
# build dataloader
|
# build dataloader
|
||||||
|
set_signal_handlers()
|
||||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||||
|
|
||||||
# build post process
|
# build post process
|
||||||
|
|||||||
@ -26,7 +26,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
|||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
import paddle.distributed as dist
|
import paddle.distributed as dist
|
||||||
from ppocr.data import build_dataloader
|
from ppocr.data import build_dataloader, set_signal_handlers
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.losses import build_loss
|
from ppocr.losses import build_loss
|
||||||
from ppocr.optimizer import build_optimizer
|
from ppocr.optimizer import build_optimizer
|
||||||
@ -57,6 +57,7 @@ def main(config, device, logger, vdl_writer):
|
|||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
|
|
||||||
# build dataloader
|
# build dataloader
|
||||||
|
set_signal_handlers()
|
||||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||||
if config['Eval']:
|
if config['Eval']:
|
||||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||||
|
|||||||
@ -34,7 +34,7 @@ from tools.program import load_config, merge_config, ArgsParser
|
|||||||
from ppocr.metrics import build_metric
|
from ppocr.metrics import build_metric
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
from paddleslim.dygraph.quant import QAT
|
from paddleslim.dygraph.quant import QAT
|
||||||
from ppocr.data import build_dataloader
|
from ppocr.data import build_dataloader, set_signal_handlers
|
||||||
from tools.export_model import export_single_model
|
from tools.export_model import export_single_model
|
||||||
|
|
||||||
|
|
||||||
@ -134,6 +134,7 @@ def main():
|
|||||||
eval_class = build_metric(config['Metric'])
|
eval_class = build_metric(config['Metric'])
|
||||||
|
|
||||||
# build dataloader
|
# build dataloader
|
||||||
|
set_signal_handlers()
|
||||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||||
|
|
||||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||||
|
|||||||
@ -31,7 +31,7 @@ import paddle.distributed as dist
|
|||||||
|
|
||||||
paddle.seed(2)
|
paddle.seed(2)
|
||||||
|
|
||||||
from ppocr.data import build_dataloader
|
from ppocr.data import build_dataloader, set_signal_handlers
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.losses import build_loss
|
from ppocr.losses import build_loss
|
||||||
from ppocr.optimizer import build_optimizer
|
from ppocr.optimizer import build_optimizer
|
||||||
@ -95,6 +95,7 @@ def main(config, device, logger, vdl_writer):
|
|||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
|
|
||||||
# build dataloader
|
# build dataloader
|
||||||
|
set_signal_handlers()
|
||||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||||
if config['Eval']:
|
if config['Eval']:
|
||||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||||
|
|||||||
@ -31,7 +31,7 @@ import paddle.distributed as dist
|
|||||||
|
|
||||||
paddle.seed(2)
|
paddle.seed(2)
|
||||||
|
|
||||||
from ppocr.data import build_dataloader
|
from ppocr.data import build_dataloader, set_signal_handlers
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.losses import build_loss
|
from ppocr.losses import build_loss
|
||||||
from ppocr.optimizer import build_optimizer
|
from ppocr.optimizer import build_optimizer
|
||||||
@ -117,6 +117,7 @@ def main(config, device, logger, vdl_writer):
|
|||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
|
|
||||||
# build dataloader
|
# build dataloader
|
||||||
|
set_signal_handlers()
|
||||||
config['Train']['loader']['num_workers'] = 0
|
config['Train']['loader']['num_workers'] = 0
|
||||||
is_layoutxlm_ser = config['Architecture']['model_type'] =='kie' and config['Architecture']['Backbone']['name'] == 'LayoutXLMForSer'
|
is_layoutxlm_ser = config['Architecture']['model_type'] =='kie' and config['Architecture']['Backbone']['name'] == 'LayoutXLMForSer'
|
||||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||||
|
|||||||
@ -39,7 +39,7 @@ from ppocr.data.pgnet_dataset import PGDataSet
|
|||||||
from ppocr.data.pubtab_dataset import PubTabDataSet
|
from ppocr.data.pubtab_dataset import PubTabDataSet
|
||||||
from ppocr.data.multi_scale_sampler import MultiScaleSampler
|
from ppocr.data.multi_scale_sampler import MultiScaleSampler
|
||||||
|
|
||||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
__all__ = ['build_dataloader', 'transform', 'create_operators', 'set_signal_handlers']
|
||||||
|
|
||||||
|
|
||||||
def term_mp(sig_num, frame):
|
def term_mp(sig_num, frame):
|
||||||
@ -51,6 +51,21 @@ def term_mp(sig_num, frame):
|
|||||||
os.killpg(pgid, signal.SIGKILL)
|
os.killpg(pgid, signal.SIGKILL)
|
||||||
|
|
||||||
|
|
||||||
|
def set_signal_handlers():
|
||||||
|
pid = os.getpid()
|
||||||
|
pgid = os.getpgid(os.getpid())
|
||||||
|
# XXX: `term_mp` kills all processes in the process group, which in
|
||||||
|
# some cases includes the parent process of current process and may
|
||||||
|
# cause unexpected results. To solve this problem, we set signal
|
||||||
|
# handlers only when current process is the group leader. In the
|
||||||
|
# future, it would be better to consider killing only descendants of
|
||||||
|
# the current process.
|
||||||
|
if pid == pgid:
|
||||||
|
# support exit using ctrl+c
|
||||||
|
signal.signal(signal.SIGINT, term_mp)
|
||||||
|
signal.signal(signal.SIGTERM, term_mp)
|
||||||
|
|
||||||
|
|
||||||
def build_dataloader(config, mode, device, logger, seed=None):
|
def build_dataloader(config, mode, device, logger, seed=None):
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
||||||
@ -109,8 +124,4 @@ def build_dataloader(config, mode, device, logger, seed=None):
|
|||||||
use_shared_memory=use_shared_memory,
|
use_shared_memory=use_shared_memory,
|
||||||
collate_fn=collate_fn)
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
# support exit using ctrl+c
|
|
||||||
signal.signal(signal.SIGINT, term_mp)
|
|
||||||
signal.signal(signal.SIGTERM, term_mp)
|
|
||||||
|
|
||||||
return data_loader
|
return data_loader
|
||||||
|
|||||||
@ -197,13 +197,26 @@ def save_model(model,
|
|||||||
"""
|
"""
|
||||||
_mkdir_if_not_exist(model_path, logger)
|
_mkdir_if_not_exist(model_path, logger)
|
||||||
model_prefix = os.path.join(model_path, prefix)
|
model_prefix = os.path.join(model_path, prefix)
|
||||||
|
|
||||||
|
if prefix == 'best_accuracy':
|
||||||
|
best_model_path = os.path.join(model_path, 'best_model')
|
||||||
|
_mkdir_if_not_exist(best_model_path, logger)
|
||||||
|
|
||||||
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
|
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
|
||||||
|
if prefix == 'best_accuracy':
|
||||||
|
paddle.save(optimizer.state_dict(),
|
||||||
|
os.path.join(best_model_path, 'model.pdopt'))
|
||||||
|
|
||||||
is_nlp_model = config['Architecture']["model_type"] == 'kie' and config[
|
is_nlp_model = config['Architecture']["model_type"] == 'kie' and config[
|
||||||
"Architecture"]["algorithm"] not in ["SDMGR"]
|
"Architecture"]["algorithm"] not in ["SDMGR"]
|
||||||
if is_nlp_model is not True:
|
if is_nlp_model is not True:
|
||||||
paddle.save(model.state_dict(), model_prefix + '.pdparams')
|
paddle.save(model.state_dict(), model_prefix + '.pdparams')
|
||||||
metric_prefix = model_prefix
|
metric_prefix = model_prefix
|
||||||
|
|
||||||
|
if prefix == 'best_accuracy':
|
||||||
|
paddle.save(model.state_dict(),
|
||||||
|
os.path.join(best_model_path, 'model.pdparams'))
|
||||||
|
|
||||||
else: # for kie system, we follow the save/load rules in NLP
|
else: # for kie system, we follow the save/load rules in NLP
|
||||||
if config['Global']['distributed']:
|
if config['Global']['distributed']:
|
||||||
arch = model._layers
|
arch = model._layers
|
||||||
@ -213,6 +226,10 @@ def save_model(model,
|
|||||||
arch = arch.Student
|
arch = arch.Student
|
||||||
arch.backbone.model.save_pretrained(model_prefix)
|
arch.backbone.model.save_pretrained(model_prefix)
|
||||||
metric_prefix = os.path.join(model_prefix, 'metric')
|
metric_prefix = os.path.join(model_prefix, 'metric')
|
||||||
|
|
||||||
|
if prefix == 'best_accuracy':
|
||||||
|
arch.backbone.model.save_pretrained(best_model_path)
|
||||||
|
|
||||||
# save metric and config
|
# save metric and config
|
||||||
with open(metric_prefix + '.states', 'wb') as f:
|
with open(metric_prefix + '.states', 'wb') as f:
|
||||||
pickle.dump(kwargs, f, protocol=2)
|
pickle.dump(kwargs, f, protocol=2)
|
||||||
|
|||||||
@ -24,7 +24,7 @@ sys.path.insert(0, __dir__)
|
|||||||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from ppocr.data import build_dataloader
|
from ppocr.data import build_dataloader, set_signal_handlers
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.metrics import build_metric
|
from ppocr.metrics import build_metric
|
||||||
@ -35,6 +35,7 @@ import tools.program as program
|
|||||||
def main():
|
def main():
|
||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
# build dataloader
|
# build dataloader
|
||||||
|
set_signal_handlers()
|
||||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||||
|
|
||||||
# build post process
|
# build post process
|
||||||
|
|||||||
@ -24,7 +24,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
|
|||||||
sys.path.append(__dir__)
|
sys.path.append(__dir__)
|
||||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||||
|
|
||||||
from ppocr.data import build_dataloader
|
from ppocr.data import build_dataloader, set_signal_handlers
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.utils.save_load import load_model
|
from ppocr.utils.save_load import load_model
|
||||||
@ -40,6 +40,7 @@ def main():
|
|||||||
'data_dir']
|
'data_dir']
|
||||||
config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
|
config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
|
||||||
'label_file_list']
|
'label_file_list']
|
||||||
|
set_signal_handlers()
|
||||||
eval_dataloader = build_dataloader(config, 'Eval', device, logger)
|
eval_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||||
|
|
||||||
# build post process
|
# build post process
|
||||||
|
|||||||
@ -40,7 +40,6 @@ import tools.program as program
|
|||||||
|
|
||||||
|
|
||||||
def draw_det_res(dt_boxes, config, img, img_name, save_path):
|
def draw_det_res(dt_boxes, config, img, img_name, save_path):
|
||||||
if len(dt_boxes) > 0:
|
|
||||||
import cv2
|
import cv2
|
||||||
src_im = img
|
src_im = img
|
||||||
for box in dt_boxes:
|
for box in dt_boxes:
|
||||||
|
|||||||
@ -683,7 +683,7 @@ def preprocess(is_train=False):
|
|||||||
|
|
||||||
if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
|
if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
|
||||||
save_model_dir = config['Global']['save_model_dir']
|
save_model_dir = config['Global']['save_model_dir']
|
||||||
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
vdl_writer_path = save_model_dir
|
||||||
log_writer = VDLLogger(vdl_writer_path)
|
log_writer = VDLLogger(vdl_writer_path)
|
||||||
loggers.append(log_writer)
|
loggers.append(log_writer)
|
||||||
if ('use_wandb' in config['Global'] and
|
if ('use_wandb' in config['Global'] and
|
||||||
|
|||||||
@ -27,7 +27,7 @@ import yaml
|
|||||||
import paddle
|
import paddle
|
||||||
import paddle.distributed as dist
|
import paddle.distributed as dist
|
||||||
|
|
||||||
from ppocr.data import build_dataloader
|
from ppocr.data import build_dataloader, set_signal_handlers
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.losses import build_loss
|
from ppocr.losses import build_loss
|
||||||
from ppocr.optimizer import build_optimizer
|
from ppocr.optimizer import build_optimizer
|
||||||
@ -49,6 +49,7 @@ def main(config, device, logger, vdl_writer):
|
|||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
|
|
||||||
# build dataloader
|
# build dataloader
|
||||||
|
set_signal_handlers()
|
||||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||||
if len(train_dataloader) == 0:
|
if len(train_dataloader) == 0:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user