This commit is contained in:
tink2123 2021-09-29 02:48:11 +00:00
parent 560f2f4984
commit 93118497f4
3 changed files with 4 additions and 7 deletions

View File

@ -37,7 +37,7 @@ Optimizer:
Architecture:
model_type: rec
algorithm: seed
algorithm: SEED
Transform:
name: STN_ON
tps_inputsize: [32, 64]

View File

@ -28,9 +28,10 @@ def build_backbone(config, model_type):
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
from .rec_resnet_aster import ResNet_ASTER
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
"ResNet31"
"ResNet31", "ResNet_ASTER"
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
@ -39,9 +40,6 @@ def build_backbone(config, model_type):
from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"]
elif model_type == "seed":
from .rec_resnet_aster import ResNet_ASTER
support_dict = ["ResNet_ASTER"]
else:
raise NotImplementedError

View File

@ -402,8 +402,7 @@ def preprocess(is_train=False):
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'ASTER'
]
'SEED']
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)