mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-11-04 11:49:14 +00:00 
			
		
		
		
	Merge pull request #1123 from WenmuZhou/dygraph_rc
fix some error and make some change
This commit is contained in:
		
						commit
						c93b4a171d
					
				@ -44,9 +44,9 @@ Optimizer:
 | 
			
		||||
  name: Adam
 | 
			
		||||
  beta1: 0.9
 | 
			
		||||
  beta2: 0.999
 | 
			
		||||
  learning_rate:
 | 
			
		||||
  lr:
 | 
			
		||||
#    name: Cosine
 | 
			
		||||
    lr: 0.001
 | 
			
		||||
    learning_rate: 0.001
 | 
			
		||||
#    warmup_epoch: 0
 | 
			
		||||
  regularizer:
 | 
			
		||||
    name: 'L2'
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ Global:
 | 
			
		||||
  save_model_dir: ./output/rec/mv3_none_bilstm_ctc/
 | 
			
		||||
  save_epoch_step: 3
 | 
			
		||||
  # evaluation is run every 5000 iterations after the 4000th iteration
 | 
			
		||||
  eval_batch_step: [0, 1000]
 | 
			
		||||
  eval_batch_step: [0, 2000]
 | 
			
		||||
  # if pretrained_model is saved in static mode, load_static_weights must set to True
 | 
			
		||||
  cal_metric_during_train: True
 | 
			
		||||
  pretrained_model:
 | 
			
		||||
@ -18,22 +18,19 @@ Global:
 | 
			
		||||
  character_dict_path: 
 | 
			
		||||
  character_type: en
 | 
			
		||||
  max_text_length: 25
 | 
			
		||||
  loss_type: ctc
 | 
			
		||||
  infer_mode: False
 | 
			
		||||
#   use_space_char: True
 | 
			
		||||
 | 
			
		||||
#   use_tps: False
 | 
			
		||||
  use_space_char: False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Optimizer:
 | 
			
		||||
  name: Adam
 | 
			
		||||
  beta1: 0.9
 | 
			
		||||
  beta2: 0.999
 | 
			
		||||
  learning_rate:
 | 
			
		||||
    lr: 0.0005
 | 
			
		||||
  lr:
 | 
			
		||||
    learning_rate: 0.0005
 | 
			
		||||
  regularizer:
 | 
			
		||||
    name: 'L2'
 | 
			
		||||
    factor: 0.00001
 | 
			
		||||
    factor: 0
 | 
			
		||||
 | 
			
		||||
Architecture:
 | 
			
		||||
  model_type: rec
 | 
			
		||||
@ -49,7 +46,7 @@ Architecture:
 | 
			
		||||
    hidden_size: 96
 | 
			
		||||
  Head:
 | 
			
		||||
    name: CTCHead
 | 
			
		||||
    fc_decay: 0.0004
 | 
			
		||||
    fc_decay: 0
 | 
			
		||||
 | 
			
		||||
Loss:
 | 
			
		||||
  name: CTCLoss
 | 
			
		||||
@ -75,8 +72,8 @@ Train:
 | 
			
		||||
      - KeepKeys:
 | 
			
		||||
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
 | 
			
		||||
  loader:
 | 
			
		||||
    shuffle: True
 | 
			
		||||
    batch_size_per_card: 256
 | 
			
		||||
    shuffle: False
 | 
			
		||||
    drop_last: True
 | 
			
		||||
    num_workers: 8
 | 
			
		||||
 | 
			
		||||
@ -97,4 +94,4 @@ Eval:
 | 
			
		||||
    shuffle: False
 | 
			
		||||
    drop_last: False
 | 
			
		||||
    batch_size_per_card: 256
 | 
			
		||||
    num_workers: 2
 | 
			
		||||
    num_workers: 4
 | 
			
		||||
 | 
			
		||||
@ -11,13 +11,9 @@
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import copy
 | 
			
		||||
import numpy as np
 | 
			
		||||
import os
 | 
			
		||||
import random
 | 
			
		||||
import paddle
 | 
			
		||||
from paddle.io import Dataset
 | 
			
		||||
import time
 | 
			
		||||
import lmdb
 | 
			
		||||
import cv2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -11,13 +11,10 @@
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import copy
 | 
			
		||||
import numpy as np
 | 
			
		||||
import os
 | 
			
		||||
import random
 | 
			
		||||
import paddle
 | 
			
		||||
from paddle.io import Dataset
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
from .imaug import transform, create_operators
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -23,8 +23,8 @@ __all__ = ['build_metric']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def build_metric(config):
 | 
			
		||||
    from .DetMetric import DetMetric
 | 
			
		||||
    from .RecMetric import RecMetric
 | 
			
		||||
    from .det_metric import DetMetric
 | 
			
		||||
    from .rec_metric import RecMetric
 | 
			
		||||
 | 
			
		||||
    support_dict = ['DetMetric', 'RecMetric']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -58,7 +58,7 @@ class Head(nn.Layer):
 | 
			
		||||
            stride=2,
 | 
			
		||||
            weight_attr=ParamAttr(
 | 
			
		||||
                name=name_list[2] + '.w_0',
 | 
			
		||||
                initializer=paddle.nn.initializer.KaimingNormal()),
 | 
			
		||||
                initializer=paddle.nn.initializer.KaimingUniform()),
 | 
			
		||||
            bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2"))
 | 
			
		||||
        self.conv_bn2 = nn.BatchNorm(
 | 
			
		||||
            num_channels=in_channels // 4,
 | 
			
		||||
@ -78,7 +78,7 @@ class Head(nn.Layer):
 | 
			
		||||
            stride=2,
 | 
			
		||||
            weight_attr=ParamAttr(
 | 
			
		||||
                name=name_list[4] + '.w_0',
 | 
			
		||||
                initializer=paddle.nn.initializer.KaimingNormal()),
 | 
			
		||||
                initializer=paddle.nn.initializer.KaimingUniform()),
 | 
			
		||||
            bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -26,7 +26,7 @@ class DBFPN(nn.Layer):
 | 
			
		||||
    def __init__(self, in_channels, out_channels, **kwargs):
 | 
			
		||||
        super(DBFPN, self).__init__()
 | 
			
		||||
        self.out_channels = out_channels
 | 
			
		||||
        weight_attr = paddle.nn.initializer.KaimingNormal()
 | 
			
		||||
        weight_attr = paddle.nn.initializer.KaimingUniform()
 | 
			
		||||
 | 
			
		||||
        self.in2_conv = nn.Conv2D(
 | 
			
		||||
            in_channels=in_channels[0],
 | 
			
		||||
@ -97,17 +97,20 @@ class DBFPN(nn.Layer):
 | 
			
		||||
        in3 = self.in3_conv(c3)
 | 
			
		||||
        in2 = self.in2_conv(c2)
 | 
			
		||||
 | 
			
		||||
        out4 = in4 + F.upsample(in5, scale_factor=2, mode="nearest")  # 1/16
 | 
			
		||||
        out3 = in3 + F.upsample(out4, scale_factor=2, mode="nearest")  # 1/8
 | 
			
		||||
        out2 = in2 + F.upsample(out3, scale_factor=2, mode="nearest")  # 1/4
 | 
			
		||||
        out4 = in4 + F.upsample(
 | 
			
		||||
            in5, scale_factor=2, mode="nearest", align_mode=1)  # 1/16
 | 
			
		||||
        out3 = in3 + F.upsample(
 | 
			
		||||
            out4, scale_factor=2, mode="nearest", align_mode=1)  # 1/8
 | 
			
		||||
        out2 = in2 + F.upsample(
 | 
			
		||||
            out3, scale_factor=2, mode="nearest", align_mode=1)  # 1/4
 | 
			
		||||
 | 
			
		||||
        p5 = self.p5_conv(in5)
 | 
			
		||||
        p4 = self.p4_conv(out4)
 | 
			
		||||
        p3 = self.p3_conv(out3)
 | 
			
		||||
        p2 = self.p2_conv(out2)
 | 
			
		||||
        p5 = F.upsample(p5, scale_factor=8, mode="nearest")
 | 
			
		||||
        p4 = F.upsample(p4, scale_factor=4, mode="nearest")
 | 
			
		||||
        p3 = F.upsample(p3, scale_factor=2, mode="nearest")
 | 
			
		||||
        p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
 | 
			
		||||
        p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
 | 
			
		||||
        p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
 | 
			
		||||
 | 
			
		||||
        fuse = paddle.concat([p5, p4, p3, p2], axis=1)
 | 
			
		||||
        return fuse
 | 
			
		||||
 | 
			
		||||
@ -29,7 +29,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 | 
			
		||||
        lr_name = lr_config.pop('name')
 | 
			
		||||
        lr = getattr(learning_rate, lr_name)(**lr_config)()
 | 
			
		||||
    else:
 | 
			
		||||
        lr = lr_config['lr']
 | 
			
		||||
        lr = lr_config['learning_rate']
 | 
			
		||||
    return lr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -37,8 +37,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
 | 
			
		||||
    from . import regularizer, optimizer
 | 
			
		||||
    config = copy.deepcopy(config)
 | 
			
		||||
    # step1 build lr
 | 
			
		||||
    lr = build_lr_scheduler(
 | 
			
		||||
        config.pop('learning_rate'), epochs, step_each_epoch)
 | 
			
		||||
    lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
 | 
			
		||||
 | 
			
		||||
    # step2 build regularization
 | 
			
		||||
    if 'regularizer' in config and config['regularizer'] is not None:
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,7 @@ from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
from __future__ import unicode_literals
 | 
			
		||||
 | 
			
		||||
from paddle.optimizer import lr as lr_scheduler
 | 
			
		||||
from paddle.optimizer import lr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Linear(object):
 | 
			
		||||
@ -32,7 +32,7 @@ class Linear(object):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 lr,
 | 
			
		||||
                 learning_rate,
 | 
			
		||||
                 epochs,
 | 
			
		||||
                 step_each_epoch,
 | 
			
		||||
                 end_lr=0.0,
 | 
			
		||||
@ -41,7 +41,7 @@ class Linear(object):
 | 
			
		||||
                 last_epoch=-1,
 | 
			
		||||
                 **kwargs):
 | 
			
		||||
        super(Linear, self).__init__()
 | 
			
		||||
        self.lr = lr
 | 
			
		||||
        self.learning_rate = learning_rate
 | 
			
		||||
        self.epochs = epochs * step_each_epoch
 | 
			
		||||
        self.end_lr = end_lr
 | 
			
		||||
        self.power = power
 | 
			
		||||
@ -49,18 +49,18 @@ class Linear(object):
 | 
			
		||||
        self.warmup_epoch = warmup_epoch * step_each_epoch
 | 
			
		||||
 | 
			
		||||
    def __call__(self):
 | 
			
		||||
        learning_rate = lr_scheduler.PolynomialLR(
 | 
			
		||||
            learning_rate=self.lr,
 | 
			
		||||
        learning_rate = lr.PolynomialDecay(
 | 
			
		||||
            learning_rate=self.learning_rate,
 | 
			
		||||
            decay_steps=self.epochs,
 | 
			
		||||
            end_lr=self.end_lr,
 | 
			
		||||
            power=self.power,
 | 
			
		||||
            last_epoch=self.last_epoch)
 | 
			
		||||
        if self.warmup_epoch > 0:
 | 
			
		||||
            learning_rate = lr_scheduler.LinearLrWarmup(
 | 
			
		||||
            learning_rate = lr.LinearWarmup(
 | 
			
		||||
                learning_rate=learning_rate,
 | 
			
		||||
                warmup_steps=self.warmup_epoch,
 | 
			
		||||
                start_lr=0.0,
 | 
			
		||||
                end_lr=self.lr,
 | 
			
		||||
                end_lr=self.learning_rate,
 | 
			
		||||
                last_epoch=self.last_epoch)
 | 
			
		||||
        return learning_rate
 | 
			
		||||
 | 
			
		||||
@ -77,27 +77,29 @@ class Cosine(object):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 lr,
 | 
			
		||||
                 learning_rate,
 | 
			
		||||
                 step_each_epoch,
 | 
			
		||||
                 epochs,
 | 
			
		||||
                 warmup_epoch=0,
 | 
			
		||||
                 last_epoch=-1,
 | 
			
		||||
                 **kwargs):
 | 
			
		||||
        super(Cosine, self).__init__()
 | 
			
		||||
        self.lr = lr
 | 
			
		||||
        self.learning_rate = learning_rate
 | 
			
		||||
        self.T_max = step_each_epoch * epochs
 | 
			
		||||
        self.last_epoch = last_epoch
 | 
			
		||||
        self.warmup_epoch = warmup_epoch * step_each_epoch
 | 
			
		||||
 | 
			
		||||
    def __call__(self):
 | 
			
		||||
        learning_rate = lr_scheduler.CosineAnnealingLR(
 | 
			
		||||
            learning_rate=self.lr, T_max=self.T_max, last_epoch=self.last_epoch)
 | 
			
		||||
        learning_rate = lr.CosineAnnealingDecay(
 | 
			
		||||
            learning_rate=self.learning_rate,
 | 
			
		||||
            T_max=self.T_max,
 | 
			
		||||
            last_epoch=self.last_epoch)
 | 
			
		||||
        if self.warmup_epoch > 0:
 | 
			
		||||
            learning_rate = lr_scheduler.LinearLrWarmup(
 | 
			
		||||
            learning_rate = lr.LinearWarmup(
 | 
			
		||||
                learning_rate=learning_rate,
 | 
			
		||||
                warmup_steps=self.warmup_epoch,
 | 
			
		||||
                start_lr=0.0,
 | 
			
		||||
                end_lr=self.lr,
 | 
			
		||||
                end_lr=self.learning_rate,
 | 
			
		||||
                last_epoch=self.last_epoch)
 | 
			
		||||
        return learning_rate
 | 
			
		||||
 | 
			
		||||
@ -115,7 +117,7 @@ class Step(object):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 lr,
 | 
			
		||||
                 learning_rate,
 | 
			
		||||
                 step_size,
 | 
			
		||||
                 step_each_epoch,
 | 
			
		||||
                 gamma,
 | 
			
		||||
@ -124,23 +126,23 @@ class Step(object):
 | 
			
		||||
                 **kwargs):
 | 
			
		||||
        super(Step, self).__init__()
 | 
			
		||||
        self.step_size = step_each_epoch * step_size
 | 
			
		||||
        self.lr = lr
 | 
			
		||||
        self.learning_rate = learning_rate
 | 
			
		||||
        self.gamma = gamma
 | 
			
		||||
        self.last_epoch = last_epoch
 | 
			
		||||
        self.warmup_epoch = warmup_epoch * step_each_epoch
 | 
			
		||||
 | 
			
		||||
    def __call__(self):
 | 
			
		||||
        learning_rate = lr_scheduler.StepLR(
 | 
			
		||||
            learning_rate=self.lr,
 | 
			
		||||
        learning_rate = lr.StepDecay(
 | 
			
		||||
            learning_rate=self.learning_rate,
 | 
			
		||||
            step_size=self.step_size,
 | 
			
		||||
            gamma=self.gamma,
 | 
			
		||||
            last_epoch=self.last_epoch)
 | 
			
		||||
        if self.warmup_epoch > 0:
 | 
			
		||||
            learning_rate = lr_scheduler.LinearLrWarmup(
 | 
			
		||||
            learning_rate = lr.LinearWarmup(
 | 
			
		||||
                learning_rate=learning_rate,
 | 
			
		||||
                warmup_steps=self.warmup_epoch,
 | 
			
		||||
                start_lr=0.0,
 | 
			
		||||
                end_lr=self.lr,
 | 
			
		||||
                end_lr=self.learning_rate,
 | 
			
		||||
                last_epoch=self.last_epoch)
 | 
			
		||||
        return learning_rate
 | 
			
		||||
 | 
			
		||||
@ -169,12 +171,12 @@ class Piecewise(object):
 | 
			
		||||
        self.warmup_epoch = warmup_epoch * step_each_epoch
 | 
			
		||||
 | 
			
		||||
    def __call__(self):
 | 
			
		||||
        learning_rate = lr_scheduler.PiecewiseLR(
 | 
			
		||||
        learning_rate = lr.PiecewiseDecay(
 | 
			
		||||
            boundaries=self.boundaries,
 | 
			
		||||
            values=self.values,
 | 
			
		||||
            last_epoch=self.last_epoch)
 | 
			
		||||
        if self.warmup_epoch > 0:
 | 
			
		||||
            learning_rate = lr_scheduler.LinearLrWarmup(
 | 
			
		||||
            learning_rate = lr.LinearWarmup(
 | 
			
		||||
                learning_rate=learning_rate,
 | 
			
		||||
                warmup_steps=self.warmup_epoch,
 | 
			
		||||
                start_lr=0.0,
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ logger_initialized = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@functools.lru_cache()
 | 
			
		||||
def get_logger(name='ppocr', log_file=None, log_level=logging.INFO):
 | 
			
		||||
def get_logger(name='root', log_file=None, log_level=logging.INFO):
 | 
			
		||||
    """Initialize and get a logger by name.
 | 
			
		||||
    If the logger has not been initialized, this method will initialize the
 | 
			
		||||
    logger by adding one or two handlers, otherwise the initialized logger will
 | 
			
		||||
 | 
			
		||||
@ -152,7 +152,6 @@ def train(config,
 | 
			
		||||
          pre_best_model_dict,
 | 
			
		||||
          logger,
 | 
			
		||||
          vdl_writer=None):
 | 
			
		||||
 | 
			
		||||
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
 | 
			
		||||
                                                   False)
 | 
			
		||||
    log_smooth_window = config['Global']['log_smooth_window']
 | 
			
		||||
@ -185,14 +184,13 @@ def train(config,
 | 
			
		||||
 | 
			
		||||
    for epoch in range(start_epoch, epoch_num):
 | 
			
		||||
        if epoch > 0:
 | 
			
		||||
            train_loader = build_dataloader(config, 'Train', device)
 | 
			
		||||
            train_dataloader = build_dataloader(config, 'Train', device, logger)
 | 
			
		||||
 | 
			
		||||
        for idx, batch in enumerate(train_dataloader):
 | 
			
		||||
            if idx >= len(train_dataloader):
 | 
			
		||||
                break
 | 
			
		||||
            lr = optimizer.get_lr()
 | 
			
		||||
            t1 = time.time()
 | 
			
		||||
            batch = [paddle.to_tensor(x) for x in batch]
 | 
			
		||||
            images = batch[0]
 | 
			
		||||
            preds = model(images)
 | 
			
		||||
            loss = loss_class(preds, batch)
 | 
			
		||||
@ -301,11 +299,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
 | 
			
		||||
    with paddle.no_grad():
 | 
			
		||||
        total_frame = 0.0
 | 
			
		||||
        total_time = 0.0
 | 
			
		||||
        #         pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
 | 
			
		||||
        pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
 | 
			
		||||
        for idx, batch in enumerate(valid_dataloader):
 | 
			
		||||
            if idx >= len(valid_dataloader):
 | 
			
		||||
                break
 | 
			
		||||
            images = paddle.to_tensor(batch[0])
 | 
			
		||||
            images = batch[0]
 | 
			
		||||
            start = time.time()
 | 
			
		||||
            preds = model(images)
 | 
			
		||||
 | 
			
		||||
@ -315,15 +313,15 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
 | 
			
		||||
            total_time += time.time() - start
 | 
			
		||||
            # Evaluate the results of the current batch
 | 
			
		||||
            eval_class(post_result, batch)
 | 
			
		||||
            #             pbar.update(1)
 | 
			
		||||
            pbar.update(1)
 | 
			
		||||
            total_frame += len(images)
 | 
			
		||||
            if idx % print_batch_step == 0 and dist.get_rank() == 0:
 | 
			
		||||
                logger.info('tackling images for eval: {}/{}'.format(
 | 
			
		||||
                    idx, len(valid_dataloader)))
 | 
			
		||||
            # if idx % print_batch_step == 0 and dist.get_rank() == 0:
 | 
			
		||||
            #     logger.info('tackling images for eval: {}/{}'.format(
 | 
			
		||||
            #         idx, len(valid_dataloader)))
 | 
			
		||||
        # Get final metirc,eg. acc or hmean
 | 
			
		||||
        metirc = eval_class.get_metric()
 | 
			
		||||
 | 
			
		||||
#         pbar.close()
 | 
			
		||||
    pbar.close()
 | 
			
		||||
    model.train()
 | 
			
		||||
    metirc['fps'] = total_frame / total_time
 | 
			
		||||
    return metirc
 | 
			
		||||
@ -354,7 +352,8 @@ def preprocess():
 | 
			
		||||
    with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
 | 
			
		||||
        yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
 | 
			
		||||
 | 
			
		||||
    logger = get_logger(log_file='{}/train.log'.format(save_model_dir))
 | 
			
		||||
    logger = get_logger(
 | 
			
		||||
        name='root', log_file='{}/train.log'.format(save_model_dir))
 | 
			
		||||
    if config['Global']['use_visualdl']:
 | 
			
		||||
        from visualdl import LogWriter
 | 
			
		||||
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
 | 
			
		||||
 | 
			
		||||
@ -36,7 +36,6 @@ from ppocr.optimizer import build_optimizer
 | 
			
		||||
from ppocr.postprocess import build_post_process
 | 
			
		||||
from ppocr.metrics import build_metric
 | 
			
		||||
from ppocr.utils.save_load import init_model
 | 
			
		||||
from ppocr.utils.utility import print_dict
 | 
			
		||||
import tools.program as program
 | 
			
		||||
 | 
			
		||||
dist.get_world_size()
 | 
			
		||||
@ -81,10 +80,11 @@ def main(config, device, logger, vdl_writer):
 | 
			
		||||
 | 
			
		||||
    # build metric
 | 
			
		||||
    eval_class = build_metric(config['Metric'])
 | 
			
		||||
 | 
			
		||||
    # load pretrain model
 | 
			
		||||
    pre_best_model_dict = init_model(config, model, logger, optimizer)
 | 
			
		||||
 | 
			
		||||
    logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
 | 
			
		||||
                format(len(train_dataloader), len(valid_dataloader)))
 | 
			
		||||
    # start train
 | 
			
		||||
    program.train(config, train_dataloader, valid_dataloader, device, model,
 | 
			
		||||
                  loss_class, optimizer, lr_scheduler, post_process_class,
 | 
			
		||||
@ -92,8 +92,7 @@ def main(config, device, logger, vdl_writer):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_reader(config, device, logger):
 | 
			
		||||
    loader = build_dataloader(config, 'Train', device)
 | 
			
		||||
    #     loader = build_dataloader(config, 'Eval', device)
 | 
			
		||||
    loader = build_dataloader(config, 'Train', device, logger)
 | 
			
		||||
    import time
 | 
			
		||||
    starttime = time.time()
 | 
			
		||||
    count = 0
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user