mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-11-02 02:39:16 +00:00
Toward Devkit Consistency (#10150)
* Accommodate UAPI * Fix signal handler * Save model.pdopt * Change variable name * Update vdl dir
This commit is contained in:
parent
15abbcc41e
commit
2d44a71b20
@ -25,7 +25,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||
|
||||
import paddle
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.data import build_dataloader, set_signal_handlers
|
||||
from ppocr.modeling.architectures import build_model
|
||||
|
||||
from ppocr.postprocess import build_post_process
|
||||
@ -39,6 +39,7 @@ def main(config, device, logger, vdl_writer):
|
||||
global_config = config['Global']
|
||||
|
||||
# build dataloader
|
||||
set_signal_handlers()
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
|
||||
# build post process
|
||||
|
||||
@ -26,7 +26,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.data import build_dataloader, set_signal_handlers
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.losses import build_loss
|
||||
from ppocr.optimizer import build_optimizer
|
||||
@ -57,6 +57,7 @@ def main(config, device, logger, vdl_writer):
|
||||
global_config = config['Global']
|
||||
|
||||
# build dataloader
|
||||
set_signal_handlers()
|
||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||
if config['Eval']:
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
|
||||
@ -34,7 +34,7 @@ from tools.program import load_config, merge_config, ArgsParser
|
||||
from ppocr.metrics import build_metric
|
||||
import tools.program as program
|
||||
from paddleslim.dygraph.quant import QAT
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.data import build_dataloader, set_signal_handlers
|
||||
from tools.export_model import export_single_model
|
||||
|
||||
|
||||
@ -134,6 +134,7 @@ def main():
|
||||
eval_class = build_metric(config['Metric'])
|
||||
|
||||
# build dataloader
|
||||
set_signal_handlers()
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
|
||||
@ -31,7 +31,7 @@ import paddle.distributed as dist
|
||||
|
||||
paddle.seed(2)
|
||||
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.data import build_dataloader, set_signal_handlers
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.losses import build_loss
|
||||
from ppocr.optimizer import build_optimizer
|
||||
@ -95,6 +95,7 @@ def main(config, device, logger, vdl_writer):
|
||||
global_config = config['Global']
|
||||
|
||||
# build dataloader
|
||||
set_signal_handlers()
|
||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||
if config['Eval']:
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
|
||||
@ -31,7 +31,7 @@ import paddle.distributed as dist
|
||||
|
||||
paddle.seed(2)
|
||||
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.data import build_dataloader, set_signal_handlers
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.losses import build_loss
|
||||
from ppocr.optimizer import build_optimizer
|
||||
@ -117,6 +117,7 @@ def main(config, device, logger, vdl_writer):
|
||||
global_config = config['Global']
|
||||
|
||||
# build dataloader
|
||||
set_signal_handlers()
|
||||
config['Train']['loader']['num_workers'] = 0
|
||||
is_layoutxlm_ser = config['Architecture']['model_type'] =='kie' and config['Architecture']['Backbone']['name'] == 'LayoutXLMForSer'
|
||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||
|
||||
@ -39,7 +39,7 @@ from ppocr.data.pgnet_dataset import PGDataSet
|
||||
from ppocr.data.pubtab_dataset import PubTabDataSet
|
||||
from ppocr.data.multi_scale_sampler import MultiScaleSampler
|
||||
|
||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||
__all__ = ['build_dataloader', 'transform', 'create_operators', 'set_signal_handlers']
|
||||
|
||||
|
||||
def term_mp(sig_num, frame):
|
||||
@ -51,6 +51,21 @@ def term_mp(sig_num, frame):
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
|
||||
|
||||
def set_signal_handlers():
    """Register SIGINT/SIGTERM handlers that terminate the process group.

    `term_mp` is installed as the handler for both signals, but only when
    the current process is the leader of its process group. On platforms
    where `os.getpgid` does not exist, no handler is installed and the
    function returns without raising.
    """
    pid = os.getpid()
    try:
        pgid = os.getpgid(pid)
    except AttributeError:
        # `os.getpgid` is unavailable on this platform, so group leadership
        # cannot be determined and a safe cleanup is not possible. Keep the
        # default signal handlers instead of crashing at startup.
        return

    # XXX: `term_mp` kills all processes in the process group, which in
    # some cases includes the parent process of current process and may
    # cause unexpected results. To solve this problem, we set signal
    # handlers only when current process is the group leader. In the
    # future, it would be better to consider killing only descendants of
    # the current process.
    if pid == pgid:
        # support exit using ctrl+c
        signal.signal(signal.SIGINT, term_mp)
        signal.signal(signal.SIGTERM, term_mp)
|
||||
|
||||
|
||||
def build_dataloader(config, mode, device, logger, seed=None):
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
@ -109,8 +124,4 @@ def build_dataloader(config, mode, device, logger, seed=None):
|
||||
use_shared_memory=use_shared_memory,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
# support exit using ctrl+c
|
||||
signal.signal(signal.SIGINT, term_mp)
|
||||
signal.signal(signal.SIGTERM, term_mp)
|
||||
|
||||
return data_loader
|
||||
|
||||
@ -197,13 +197,26 @@ def save_model(model,
|
||||
"""
|
||||
_mkdir_if_not_exist(model_path, logger)
|
||||
model_prefix = os.path.join(model_path, prefix)
|
||||
|
||||
if prefix == 'best_accuracy':
|
||||
best_model_path = os.path.join(model_path, 'best_model')
|
||||
_mkdir_if_not_exist(best_model_path, logger)
|
||||
|
||||
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
|
||||
if prefix == 'best_accuracy':
|
||||
paddle.save(optimizer.state_dict(),
|
||||
os.path.join(best_model_path, 'model.pdopt'))
|
||||
|
||||
is_nlp_model = config['Architecture']["model_type"] == 'kie' and config[
|
||||
"Architecture"]["algorithm"] not in ["SDMGR"]
|
||||
if is_nlp_model is not True:
|
||||
paddle.save(model.state_dict(), model_prefix + '.pdparams')
|
||||
metric_prefix = model_prefix
|
||||
|
||||
if prefix == 'best_accuracy':
|
||||
paddle.save(model.state_dict(),
|
||||
os.path.join(best_model_path, 'model.pdparams'))
|
||||
|
||||
else: # for kie system, we follow the save/load rules in NLP
|
||||
if config['Global']['distributed']:
|
||||
arch = model._layers
|
||||
@ -213,6 +226,10 @@ def save_model(model,
|
||||
arch = arch.Student
|
||||
arch.backbone.model.save_pretrained(model_prefix)
|
||||
metric_prefix = os.path.join(model_prefix, 'metric')
|
||||
|
||||
if prefix == 'best_accuracy':
|
||||
arch.backbone.model.save_pretrained(best_model_path)
|
||||
|
||||
# save metric and config
|
||||
with open(metric_prefix + '.states', 'wb') as f:
|
||||
pickle.dump(kwargs, f, protocol=2)
|
||||
|
||||
@ -24,7 +24,7 @@ sys.path.insert(0, __dir__)
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
import paddle
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.data import build_dataloader, set_signal_handlers
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
@ -35,6 +35,7 @@ import tools.program as program
|
||||
def main():
|
||||
global_config = config['Global']
|
||||
# build dataloader
|
||||
set_signal_handlers()
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
|
||||
# build post process
|
||||
|
||||
@ -24,7 +24,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.data import build_dataloader, set_signal_handlers
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import load_model
|
||||
@ -40,6 +40,7 @@ def main():
|
||||
'data_dir']
|
||||
config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
|
||||
'label_file_list']
|
||||
set_signal_handlers()
|
||||
eval_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
|
||||
# build post process
|
||||
|
||||
@ -40,17 +40,16 @@ import tools.program as program
|
||||
|
||||
|
||||
def draw_det_res(dt_boxes, config, img, img_name, save_path):
|
||||
if len(dt_boxes) > 0:
|
||||
import cv2
|
||||
src_im = img
|
||||
for box in dt_boxes:
|
||||
box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
save_path = os.path.join(save_path, os.path.basename(img_name))
|
||||
cv2.imwrite(save_path, src_im)
|
||||
logger.info("The detected Image saved in {}".format(save_path))
|
||||
import cv2
|
||||
src_im = img
|
||||
for box in dt_boxes:
|
||||
box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
save_path = os.path.join(save_path, os.path.basename(img_name))
|
||||
cv2.imwrite(save_path, src_im)
|
||||
logger.info("The detected Image saved in {}".format(save_path))
|
||||
|
||||
|
||||
@paddle.no_grad()
|
||||
|
||||
@ -683,7 +683,7 @@ def preprocess(is_train=False):
|
||||
|
||||
if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
||||
vdl_writer_path = save_model_dir
|
||||
log_writer = VDLLogger(vdl_writer_path)
|
||||
loggers.append(log_writer)
|
||||
if ('use_wandb' in config['Global'] and
|
||||
|
||||
@ -27,7 +27,7 @@ import yaml
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.data import build_dataloader, set_signal_handlers
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.losses import build_loss
|
||||
from ppocr.optimizer import build_optimizer
|
||||
@ -49,6 +49,7 @@ def main(config, device, logger, vdl_writer):
|
||||
global_config = config['Global']
|
||||
|
||||
# build dataloader
|
||||
set_signal_handlers()
|
||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||
if len(train_dataloader) == 0:
|
||||
logger.error(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user