Merge pull request #4196 from WenmuZhou/fx_pse

add pse to windows_not_support_list
This commit is contained in:
MissPenguin 2021-09-29 10:43:47 +08:00 committed by GitHub
commit 2e67b9ae8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 16 deletions

View File

@ -18,6 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals
import copy
import platform
__all__ = ['build_post_process']
@ -28,7 +29,10 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
TableLabelDecode, NRTRLabelDecode, SARLabelDecode , SEEDLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .pse_postprocess import PSEPostProcess
if platform.system() != "Windows":
# pse is not support in Windows
from .pse_postprocess import PSEPostProcess
def build_post_process(config, global_config=None):

View File

@ -394,21 +394,6 @@ def preprocess(is_train=False):
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'ASTER'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
config['Global']['distributed'] = dist.get_world_size() != 1
if is_train:
# save_config
save_model_dir = config['Global']['save_model_dir']
@ -420,6 +405,28 @@ def preprocess(is_train=False):
else:
log_file = None
logger = get_logger(name='root', log_file=log_file)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'ASTER'
]
windows_not_support_list = ['PSE']
if platform.system() == "Windows" and alg in windows_not_support_list:
logger.warning('{} is not support in Windows now'.format(
windows_not_support_list))
sys.exit()
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
config['Global']['distributed'] = dist.get_world_size() != 1
if config['Global']['use_visualdl']:
from visualdl import LogWriter
save_model_dir = config['Global']['save_model_dir']